In [1]:

import torch
import torch.nn as nn
import torchvision.transforms as transforms
import timm
import torchsummary
import imageio
import numpy as np
import matplotlib.pyplot as plt
import os
from torch.cuda.amp import GradScaler, autocast

from tqdm.notebook import tqdm

import sys
sys.path.append('../DataLoader')

from dataloader import SunImageDataset

from torch.func import stack_module_state
from torch.func import vmap

from lightning.fabric import Fabric

In [2]:
# Select the 'medium' float32 matmul precision setting for the whole process.
torch.set_float32_matmul_precision('medium')

# One CUDA device, bf16 mixed precision; Fabric is used below to set up the
# model/optimizer and the dataloaders.
fabric = Fabric(
    accelerator='cuda',
    devices=1,
    precision="bf16-mixed",
)
fabric.launch()

# Show which device Fabric selected.
print(fabric.device)

Using bfloat16 Automatic Mixed Precision (AMP)


cuda:0


In [None]:
# Hyper-parameters
input_size = 224 * 224    # number of pixels in one 224x224 image
hidden_size = 166         # feature width: Swin head output size and FC hidden size
num_epochs = 100          # full passes over the training loader
batch_size = 2            # samples per batch for the data loaders
learning_rate = 1e-3      # Adam step size
# Dropout value kept for reference only (the model currently hard-codes p=0.5):
# dropout = 0.6990787087509548

In [4]:
# Where the dataset index lives and how much of it is used for training.
csv_path = "D:\\dataset.csv"
train_fraction = 0.8

dataset = SunImageDataset(csv_file=csv_path, offset=0, transform=transforms.ToTensor())

# Sequential (non-shuffled) split: the first `train_size` samples are used for
# training, the remainder for testing.
n_samples = len(dataset)
train_size = int(train_fraction * n_samples)
test_size = n_samples - train_size

trainset = torch.utils.data.Subset(dataset, range(train_size))
testset = torch.utils.data.Subset(dataset, range(train_size, n_samples))

# Shuffle within the training subset only, then hand the loader to Fabric.
trainloader = fabric.setup_dataloaders(
    torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
)

In [5]:

class GmiSwinTransformer(nn.Module):
    """Regress one scalar per sample from a stack of single-channel images.

    Every image of a sample is replicated to three channels, batch-normalised,
    and encoded by a pretrained Swin Transformer into ``hidden_size`` features.
    The features of all images belonging to one sample are concatenated and
    passed through a small fully connected head that emits a single value.

    Args:
        hidden_size: Number of features the Swin backbone produces per image
            (passed as its ``num_classes``); also the hidden width of the head.
        num_images: Number of images that make up one sample. The first linear
            layer of the head expects ``hidden_size * num_images`` inputs.
            Defaults to 10, the value that was previously hard-coded.
    """

    def __init__(self, hidden_size: int, num_images: int = 10):
        super().__init__()

        # Learnable input normalisation over the 3 replicated channels.
        self.bn = nn.BatchNorm2d(3)

        # Pretrained Swin backbone whose classification head is resized to
        # produce `hidden_size` features per image.
        self.pretrained_model = timm.create_model(
            'swin_base_patch4_window7_224',
            pretrained=True,
            num_classes=hidden_size
        )

        # Regression head applied to the concatenated per-image features.
        # NOTE(review): the final LeakyReLU attenuates negative predictions;
        # confirm this is intended for the regression target.
        self.fc = nn.Sequential(
            nn.LeakyReLU(),
            nn.Linear(hidden_size * num_images, hidden_size),
            nn.Dropout(p=0.5),
            nn.LeakyReLU(),
            nn.Linear(hidden_size, 1),
            nn.LeakyReLU()
        )

    def forward(self, images) -> torch.Tensor:
        """Compute one prediction per sample.

        Args:
            images: Tensor whose first dimension is the batch and whose
                remaining dimensions hold ``num_images`` single-channel images
                per sample, e.g. ``(batch, num_images, 1, 224, 224)``.

        Returns:
            Tensor of shape ``(batch, 1)``.
        """
        batch_size = images.shape[0]
        height, width = images.shape[-2:]

        # Fold the per-sample image axis into the batch axis so the backbone
        # sees ordinary (N, C, H, W) input.
        images = images.reshape(-1, 1, height, width)

        # The pretrained backbone expects three channels: replicate the single one.
        images = torch.cat([images, images, images], dim=1)

        normalized_images = self.bn(images)
        features = self.pretrained_model(normalized_images)

        # Regroup: one row per sample holding all of its images' features.
        image_features = features.reshape(batch_size, -1)

        return self.fc(image_features)

# Device handle for code that places tensors manually; the model itself is
# placed by fabric.setup() below rather than by an explicit .to(device).
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the network with the configured feature width.
model = GmiSwinTransformer(hidden_size=hidden_size)

# Mean-squared-error objective, optimised with Adam over every model parameter.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Hand the model/optimizer pair to Fabric; the returned wrappers replace them.
model, optimizer = fabric.setup(model, optimizer)

# Put the model in training mode (as the cell's last expression this also
# displays the wrapped module).
model.train()

_FabricModule(
  (_forward_module): GmiSwinTransformer(
    (bn): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (pretrained_model): SwinTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
        (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      )
      (layers): Sequential(
        (0): SwinTransformerStage(
          (downsample): Identity()
          (blocks): Sequential(
            (0): SwinTransformerBlock(
              (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
              (attn): WindowAttention(
                (qkv): Linear(in_features=128, out_features=384, bias=True)
                (attn_drop): Dropout(p=0.0, inplace=False)
                (proj): Linear(in_features=128, out_features=128, bias=True)
                (proj_drop): Dropout(p=0.0, inplace=False)
                (softmax): Softmax(dim=-1)
              )
              (drop_pa

In [6]:
# Ask PyTorch to release cached, currently unused CUDA memory before the
# training cell below runs.
torch.cuda.empty_cache()

In [None]:
# Training the model
n_total_steps = len(trainloader)
for epoch in range(num_epochs):
    # Re-assert training mode every epoch so this cell still trains correctly
    # when it is re-run after the evaluation cell has called model.eval().
    model.train()
    train_losses = []

    progress = tqdm(enumerate(trainloader), desc="Training Progress", total=n_total_steps)
    for i, (images, labels) in progress:
        # The loader yields `images` as a sequence of tensors; stack them and
        # swap the first two axes so the batch dimension comes first.
        # (Presumably the stacked shape is (num_images, batch, 1, 224, 224) —
        # confirm against SunImageDataset.) Device placement is handled by the
        # Fabric-wrapped dataloader, so no explicit .to(device) is needed here.
        images = torch.stack(images).float()
        images = images.permute(1, 0, 2, 3, 4)
        labels = labels.float()

        # Forward pass; drop the trailing singleton dimension so predictions
        # and labels have matching shapes for the loss.
        outputs = model(images).squeeze(1)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        fabric.backward(loss)
        optimizer.step()

        # Read the scalar once and reuse it for both logging and bookkeeping.
        batch_loss = loss.item()
        train_losses.append(batch_loss)
        tqdm.write(f"Epoch: {epoch+1}, Index: {i}, Loss: {batch_loss:.4f}")

    if not train_losses:
        # Nothing was iterated; there is no loss to report or plot.
        print(f'Epoch [{epoch+1}/{num_epochs}]: no training batches')
        continue

    # Report the mean over the epoch rather than only the last batch's loss.
    mean_loss = sum(train_losses) / len(train_losses)
    print(f'Epoch [{epoch+1}/{num_epochs}], Mean loss: {mean_loss:.4f}')

    # The curve shows one point per batch of this epoch.
    plt.plot(train_losses, label='Training loss')
    plt.xlabel('Batch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
    

Training Progress:   0%|          | 0/205 [00:00<?, ?it/s]

Epoch: 1, Index: 0, Loss: 3.4264
Epoch: 1, Index: 1, Loss: 3.8210
Epoch: 1, Index: 2, Loss: 76.0761
Epoch: 1, Index: 3, Loss: 3.2861
Epoch: 1, Index: 4, Loss: 2.0356
Epoch: 1, Index: 5, Loss: 0.3006
Epoch: 1, Index: 6, Loss: 0.8847


In [None]:
# Test the model
model.eval()
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, shuffle=False)
testloader = fabric.setup_dataloaders(testloader)
test_losses = []

with torch.no_grad():
    for images, labels in tqdm(testloader, desc="Testing Progress"):
        # Same preprocessing as the training loop: stack the image sequence and
        # put the batch dimension first. The Fabric-wrapped loader handles
        # device placement and the Fabric-wrapped model handles precision, so
        # no manual .to(device) or autocast context is used here.
        images = torch.stack(images).float()
        images = images.permute(1, 0, 2, 3, 4)
        labels = labels.float()

        # Squeeze the trailing singleton dimension exactly as in training.
        # Without it MSELoss would broadcast (batch, 1) predictions against
        # (batch,) labels into a (batch, batch) matrix and report a wrong loss.
        outputs = model(images).squeeze(1)
        loss = criterion(outputs, labels)
        test_losses.append(loss.item())

# Guard against an empty test set instead of dividing by zero.
avg_test_loss = sum(test_losses) / len(test_losses) if test_losses else float('nan')
print(f'Average test loss: {avg_test_loss:.4f}')