In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import timm
import torchsummary
import imageio
import numpy as np
import matplotlib.pyplot as plt
import os
from torch.cuda.amp import GradScaler, autocast

from tqdm.notebook import tqdm

import sys
sys.path.append('../DataLoader')

from dataloader import SunImageDataset

from torch.func import vmap

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

cuda


In [3]:
# Hyper-parameters
input_size = 224*224
hidden_size = 166
num_epochs = 10
batch_size = 5
learning_rate = 0.001
# dropout = 0.6990787087509548

In [4]:
dataset = SunImageDataset(csv_file='D:/Dissertation/pytorch trial/dataset.csv', offset=0, transform=transforms.ToTensor())
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
trainset, testset = torch.utils.data.Subset(dataset, range(train_size)), torch.utils.data.Subset(dataset, range(train_size, len(dataset)))

trainloader = torch.utils.data.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)

# # Get a batch of training data
# dataiter = iter(trainloader)
# images, labels = next(dataiter)

# print(images)
# print(labels)

In [None]:
class GmiSwinTransformer(nn.Module):
    def __init__(self, hidden_size: int):
        super(GmiSwinTransformer, self).__init__()
        
        # Batch normalization for 3 channels
        self.bn = nn.BatchNorm2d(3)
        
        # Initialize Swin Transformer
        self.pretrained_model = timm.create_model(
            'swin_base_patch4_window7_224',
            pretrained=True,
            num_classes=hidden_size
        )

        
        # Fully connected layers
        self.fc = nn.Sequential(
            nn.LeakyReLU(),
            nn.Linear(hidden_size*10, hidden_size),
            nn.Dropout(p=0.5),  # Added dropout probability
            nn.LeakyReLU(),
            nn.Linear(hidden_size, 1),
            nn.LeakyReLU()
        )
    
    def signle_image_feature_extractor(self, image):
        normalized_image = self.bn(image)
        return self.pretrained_model(normalized_image)

    def gray_to_rgb(self, x):
        """Convert grayscale to RGB by replicating channels"""
        return torch.cat([x, x, x], dim=1)  # [B, 3, H, W]
    
    def forward(self, images) -> torch.Tensor:
        """
        Batch should be in format:
        {
            'images': torch.FloatTensor((10, 1, 224, 224))
        }
        """
        # # Print input shape for debugging
        # # print("Input shape:", images.shape)
        # image_features = torch.zeros(images.shape[0],images.shape[1], hidden_size).to(device)
        # for i in range(images.shape[0]):
        #     image = images[i, :, :, :]
        #     # Pretrained swin transformer accepts three channel images
        #     three_channel = torch.stack([image,image,image], dim=1).squeeze(2)
        #     # print("three_channel", three_channel.size())
        #     # Model learns optimal initial normalisation
        #     normalized_images = self.bn(three_channel)
        #     # Get image features
        #     image_feature = self.pretrained_model.forward(normalized_images)
        #     image_features[i] = image_feature
        # # print("image_features before reshaping", image_features.size())
        # image_features = image_features.view(image_features.size(0), -1)
        # # print("image_features after reshaping", image_features.size())
        
        batch_size = images.shape[0]
        
        images = images.reshape(-1, 1, 224, 224)
        images = torch.cat([images, images, images], dim=1)
        normalized_images = self.bn(images)
        features = self.pretrained_model(normalized_images)

        image_features = features.view(batch_size, -1)
        
        output = self.fc(image_features)
        return output

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GmiSwinTransformer(hidden_size=hidden_size).to(device)

print(torchsummary.summary(model, (10, 1, 224, 224)))

# Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Layer (type:depth-idx)                             Output Shape              Param #
├─BatchNorm2d: 1-1                                 [-1, 3, 224, 224]         6
├─SwinTransformer: 1                               []                        --
|    └─PatchEmbed: 2-1                             [-1, 56, 56, 128]         --
|    |    └─Conv2d: 3-1                            [-1, 128, 56, 56]         6,272
|    |    └─LayerNorm: 3-2                         [-1, 56, 56, 128]         256
|    └─Sequential: 2-2                             [-1, 7, 7, 1024]          --
|    |    └─SwinTransformerStage: 3-3              [-1, 56, 56, 128]         397,896
|    |    └─SwinTransformerStage: 3-4              [-1, 28, 28, 256]         1,714,320
|    |    └─SwinTransformerStage: 3-5              [-1, 14, 14, 512]         57,317,920
|    |    └─SwinTransformerStage: 3-6              [-1, 7, 7, 1024]          27,304,512
|    └─LayerNorm: 2-3                              [-1, 7, 7, 1024]          2,048
|

In [9]:
torch.cuda.empty_cache()

In [28]:
# Training the model
n_total_steps = len(trainloader)
for epoch in range(num_epochs):
    for i, (images, labels) in tqdm(enumerate(trainloader), desc="Training Progress", total=len(trainloader)):
        # Move images and labels to device
        images = torch.stack(images).to(device)
        # print("images before reshaping", images.size())
        images = images.permute(1, 0, 2, 3, 4)  # Change shape to [5, 10, 1, 224, 224]
        # print("images after reshaping", images.size())
        labels = labels.to(device)
        
        # Forward pass with autograd
        with torch.amp.autocast('cuda'):
            outputs = model(images)
            loss = criterion(outputs, labels)
            print("Loss", loss)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')

Training Progress:   0%|          | 0/1519 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


Loss tensor(6.6071, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(1.4254, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(24.5684, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(3.2818, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(5.0480, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(1.6216, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(4.4438, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(4.3636, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(4.7433, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(2.6246, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(5.8304, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(2.1191, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(0.2709, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(3.8496, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(5.2738, device='cuda:0', grad_fn=<MseLossBackward0>)
Loss tensor(4.3387, devi

KeyboardInterrupt: 

In [31]:
# Test the model
model.eval()
testloader = torch.utils.data.DataLoader(dataset=testset, batch_size=batch_size, shuffle=False)
test_losses = []

with torch.no_grad():
    for images, labels in tqdm(testloader, desc="Testing Progress"):
        images = torch.stack(images).to(device)
        images = images.permute(1, 0, 2, 3, 4)
        labels = labels.to(device)
        with torch.amp.autocast('cuda'):
            outputs = model(images)
            loss = criterion(outputs, labels)
        test_losses.append(loss.item())
        print("Test loss", test_losses)

avg_test_loss = sum(test_losses) / len(test_losses)
print(f'Average test loss: {avg_test_loss:.4f}')

Testing Progress:   0%|          | 0/380 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


Test loss [2.029275417327881]
Test loss [2.029275417327881, 3.117931365966797]
Test loss [2.029275417327881, 3.117931365966797, 4.787830352783203]
Test loss [2.029275417327881, 3.117931365966797, 4.787830352783203, 1.1730252504348755]
Test loss [2.029275417327881, 3.117931365966797, 4.787830352783203, 1.1730252504348755, 1.936651587486267]
Test loss [2.029275417327881, 3.117931365966797, 4.787830352783203, 1.1730252504348755, 1.936651587486267, 1.4046493768692017]
Test loss [2.029275417327881, 3.117931365966797, 4.787830352783203, 1.1730252504348755, 1.936651587486267, 1.4046493768692017, 1.4569870233535767]
Test loss [2.029275417327881, 3.117931365966797, 4.787830352783203, 1.1730252504348755, 1.936651587486267, 1.4046493768692017, 1.4569870233535767, 0.587260365486145]
Test loss [2.029275417327881, 3.117931365966797, 4.787830352783203, 1.1730252504348755, 1.936651587486267, 1.4046493768692017, 1.4569870233535767, 0.587260365486145, 1.9755123853683472]
Test loss [2.029275417327881, 3.

KeyboardInterrupt: 