In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from src.process import crop_center,get_nine_crops
import numpy as np
import random
import os
from PIL import Image

class CNNModel(nn.Module):
    """Per-patch feature extractor that maps an image to a 512-dim vector.

    Two ``conv -> ReLU -> 2x2 max-pool`` stages are followed by one fully
    connected layer with ReLU. ``fc1`` expects ``64 * 8 * 8`` inputs, which
    corresponds to 32x32 spatial input (each pool halves the resolution and
    the padded 3x3 convolutions keep it unchanged).

    ``conv3`` / ``relu3`` are created but not applied in ``forward``. They
    are kept so the parameter names still line up with checkpoints that
    contain ``conv3.*`` entries.
    """

    def __init__(self):
        super().__init__()

        # Stage 1: 3 -> 32 channels, then halve the spatial resolution.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 2: 32 -> 64 channels, then halve the spatial resolution again.
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Registered but unused in forward (see class docstring).
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()

        # Projection of the flattened 64 x 8 x 8 feature map to 512 features.
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.relu4 = nn.ReLU()

    def forward(self, x):
        """Return a ``(batch, 512)`` tensor of non-negative features for ``x``."""
        stage1 = self.conv1(x)
        stage1 = self.relu1(stage1)
        stage1 = self.maxpool1(stage1)

        stage2 = self.conv2(stage1)
        stage2 = self.relu2(stage2)
        stage2 = self.maxpool2(stage2)

        # One row per sample so the linear layer can consume it.
        flat = stage2.view(stage2.size(0), -1)
        return self.relu4(self.fc1(flat))
    
class SelfSupervisedModel(nn.Module):
    """Jigsaw pretext model that can also act as a downstream classifier.

    Args:
        siamese_deg: number of patches run through the shared ``patch_model``
            in jigsaw mode. ``None`` selects the downstream branch, where the
            input is treated as one image batch.
        num_outputs: size of the ``fc3`` output layer.

    Notes:
        ``fc2`` is always created with ``9 * 512`` input features, so the
        jigsaw branch only matches its shape when ``siamese_deg == 9``. In
        the downstream branch ``fc2`` receives the 512-dim output of
        ``patch_model``, so the caller has to replace ``fc2`` with a layer
        that accepts 512 features before calling ``forward``.
    """

    def __init__(self, siamese_deg=None, num_outputs=1000):
        super(SelfSupervisedModel, self).__init__()
        self.siamese_deg = siamese_deg
        self.patch_model = CNNModel()
        self.fc2 = nn.Linear(9 * 512, 4096)  # consumes the concatenated patch features
        self.fc3 = nn.Linear(4096, num_outputs)  # jigsaw output layer

    def forward(self, input_batch):
        """Return per-class log-probabilities of shape ``(batch, classes)``.

        Args:
            input_batch: in jigsaw mode a 5-D tensor indexed as
                ``(batch, patch, channels, height, width)``; in downstream
                mode whatever ``patch_model`` accepts.

        Raises:
            ValueError: in jigsaw mode when ``input_batch`` is not 5-D.
        """
        # Downstream task: a single pass through the patch network and fc2.
        if self.siamese_deg is None:
            batch_features = self.patch_model(input_batch)
            # dim=1 is what the deprecated implicit-dim rule chose for this
            # 2-D input; naming it silences the warning without changing values.
            return F.log_softmax(self.fc2(batch_features), dim=1)

        # Self-supervised (jigsaw) task.
        if input_batch.dim() != 5:
            raise ValueError(
                'expected a 5-D batch (batch, patches, channels, height, width), '
                'got {}-D'.format(input_batch.dim())
            )

        # Every patch goes through the same weight-shared network; each
        # input_batch[:, i] has shape (batch, channels, height, width).
        patch_features = [
            self.patch_model(input_batch[:, patch_ind])
            for patch_ind in range(self.siamese_deg)
        ]
        final_feat_vectors = torch.cat(patch_features, dim=1)

        # fc2 feeds fc3 directly, exactly as in the original forward pass.
        x = self.fc2(final_feat_vectors)
        return F.log_softmax(self.fc3(x), dim=1)
    
class GetJigsawPuzzleDataset(Dataset):
    """Dataset yielding nine shuffled image patches and the permutation index used.

    Args:
        file_paths: sequence of image file paths; one sample per path.
        avail_permuts_file_path: path of a ``.npy`` file loaded with
            ``np.load``; row ``i`` is used as permutation ``i``.
        range_permut_indices: optional ``(low, high)`` pair. When truthy the
            permutation index is drawn from ``low..high`` (both inclusive),
            otherwise from all rows of the permutation file.
        transform: callable applied to each of the nine crops. Its result is
            written into a ``(3, 32, 32)`` float slot, so it must produce a
            tensor compatible with that shape. It is required in practice:
            ``__getitem__`` calls it unconditionally.
    """

    def __init__(self, file_paths, avail_permuts_file_path, range_permut_indices=None, transform=None):
        """Store the configuration and load the permutation table."""
        self.file_paths = file_paths
        self.transform = transform
        self.permuts_avail = np.load(avail_permuts_file_path)
        self.range_permut_indices = range_permut_indices

    def __len__(self):
        """Return the total number of samples."""
        return len(self.file_paths)

    def __getitem__(self, index):
        """Return ``(tensor_patches, permut_ind)`` for the image at ``index``.

        ``tensor_patches`` has shape ``(9, 3, 32, 32)`` and ``permut_ind`` is
        the row of the permutation table that was applied.
        """
        file_path = self.file_paths[index]

        # Open inside a context manager so the file handle is released
        # deterministically; resize() returns an independent in-memory image.
        with Image.open(file_path) as opened_image:
            source_image = opened_image
            # Images without exactly three bands are converted to RGB.
            # Previously they were silently replaced by sample 0, which
            # duplicated that sample and relied on it having three bands.
            if len(source_image.getbands()) != 3:
                source_image = source_image.convert('RGB')
            pil_image = source_image.resize((128, 128))

        pil_image = crop_center(pil_image, 105, 105)

        # Nine crops of the centre-cropped image (helper from src.process).
        nine_crops = get_nine_crops(pil_image)

        # Pick which permutation to apply; randint is inclusive at both ends.
        if self.range_permut_indices:
            permut_ind = random.randint(self.range_permut_indices[0], self.range_permut_indices[1])
        else:
            permut_ind = random.randint(0, len(self.permuts_avail) - 1)

        permutation_config = self.permuts_avail[permut_ind]

        # Crop k is moved to slot permutation_config[k].
        permuted_patches_arr = [None] * 9
        for crop_new_pos, crop in zip(permutation_config, nine_crops):
            permuted_patches_arr[crop_new_pos] = crop

        # Transform every patch and pack them into one tensor.
        tensor_patches = torch.zeros(9, 3, 32, 32)
        for ind, jigsaw_patch in enumerate(permuted_patches_arr):
            tensor_patches[ind] = self.transform(jigsaw_patch)

        return tensor_patches, permut_ind
    
def get_paths(data_dir='CIFAR-10-images/test', extension='.jpg'):
    """Recursively collect the paths of image files below ``data_dir``.

    Args:
        data_dir: directory to walk. Defaults to the CIFAR-10 test folder
            used by this notebook.
        extension: filename suffix to keep (case-sensitive).

    Returns:
        list[str]: matching paths in a deterministic order (directories and
        filenames are visited in sorted order). Empty if ``data_dir`` does
        not exist or holds no matching file.
    """
    file_paths_to_return = []

    for root, dirs, files in os.walk(data_dir):
        # Sorting in place steers os.walk, so the traversal order no longer
        # depends on the filesystem's directory listing order.
        dirs.sort()
        for file in sorted(files):
            if file.endswith(extension):
                file_paths_to_return.append(os.path.join(root, file))

    return file_paths_to_return

# Jigsaw inference: checking model performance on the test dataset

In [2]:
# Setup for jigsaw (pretext task) inference
import argparse
import os

import torch
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from torch import nn, optim
from torch.utils.data import DataLoader, ConcatDataset,random_split
from torchvision.datasets import ImageFolder
from torchvision import transforms, utils, models
from src.trainer import ModelTrainTest
from torchvision import transforms

# Per-patch augmentation: random 32x32 crop, brightness jitter, then tensor conversion.
data_transform = transforms.Compose([
    transforms.RandomCrop((32, 32)),
    transforms.ColorJitter(brightness=[0.5, 1.5]),
    transforms.ToTensor(),
    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

torch.manual_seed(42)

batch_size = 32

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

test_file_path = get_paths()
permuts_file_path = 'Data/selected_permuts.npy'

test_loader = DataLoader(
    GetJigsawPuzzleDataset(test_file_path, permuts_file_path, transform=data_transform),
    batch_size=batch_size, shuffle=True, num_workers=8
)

# Number of CIFAR-10 classes (not used by the jigsaw model built below).
num_classes = 10

model_to_test = SelfSupervisedModel(siamese_deg=9)

checkpoint_path = 'Model/jigsaw_solver_CIFAR-10_trained.pt'

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)
print(checkpoint.keys())

new_checkpoint = {}

# The saved keys carry a leading 'module.' that this model's parameter names
# do not have. Strip it only as a prefix, never from the middle of a key.
for key, value in checkpoint.items():
    new_key = key[len('module.'):] if key.startswith('module.') else key
    new_checkpoint[new_key] = value

model_to_test.load_state_dict(new_checkpoint)

print('Model loaded successfully')

# Move the model to the evaluation device (the cell displays the model).
model_to_test.to(device)

odict_keys(['module.patch_model.conv1.weight', 'module.patch_model.conv1.bias', 'module.patch_model.conv2.weight', 'module.patch_model.conv2.bias', 'module.patch_model.conv3.weight', 'module.patch_model.conv3.bias', 'module.patch_model.fc1.weight', 'module.patch_model.fc1.bias', 'module.fc2.weight', 'module.fc2.bias', 'module.fc3.weight', 'module.fc3.bias'])
Model loaded successfully


SelfSupervisedModel(
  (patch_model): CNNModel(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU()
    (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2): ReLU()
    (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu3): ReLU()
    (fc1): Linear(in_features=4096, out_features=512, bias=True)
    (relu4): ReLU()
  )
  (fc2): Linear(in_features=4608, out_features=4096, bias=True)
  (fc3): Linear(in_features=4096, out_features=1000, bias=True)
)

In [3]:
def count_correct_preds(network_output, target):
    """Count the samples whose highest-scoring class equals the target.

    Args:
        network_output: tensor of per-class scores with the classes along
            dimension 1 (e.g. log-probabilities of shape ``(batch, classes)``).
        target: tensor of class indices with one entry per prediction.

    Returns:
        int: number of positions where the arg-max over dimension 1 matches
        ``target``.
    """
    # Index of the max score per sample; index tensors carry no gradient,
    # so no `.data` access is needed.
    _, pred = network_output.max(dim=1)
    return pred.view_as(target).eq(target).sum().item()

def test(network, test_data_loader):
    """Evaluate ``network`` on ``test_data_loader`` and print a summary.

    The network is expected to return log-probabilities, since the loss is
    computed with ``F.nll_loss``. Batches are moved to the module-level
    ``device`` before the forward pass.

    Args:
        network: model to evaluate; it is switched to eval mode.
        test_data_loader: iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute whose length is the number of samples.

    Returns:
        tuple: ``(test_loss, test_acc)`` — the summed NLL divided by the
        dataset size, and the fraction of correct predictions.
    """
    network.eval()
    test_loss = 0
    correct = 0
    num_samples = len(test_data_loader.dataset)

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for data, target in test_data_loader:
            data, target = data.to(device), target.to(device)
            output = network(data)
            # Sum (not mean) per batch so dividing by the dataset size below
            # gives the true per-sample average.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            correct += count_correct_preds(output, target)

    test_loss /= num_samples
    test_acc = correct / num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))

    return test_loss, test_acc

In [4]:
# Evaluate the jigsaw model on the shuffled-patch test loader.
test(model_to_test, test_loader)

  x = F.log_softmax(self.fc3(x))



Test set: Average loss: 0.1855, Accuracy: 9533/10000 (95%)



# Jigsaw downstream task (image classification): inference on the test dataset

In [5]:
from torchvision import transforms
# Deterministic preprocessing for the classification test set: convert to a
# tensor, then normalise each channel with fixed mean/std values.
# NOTE(review): these look like the usual ImageNet statistics rather than
# CIFAR-10's own; presumably they match what the model was trained with. Confirm.
normal_data_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [6]:
#for downstream classification
import argparse
import os

import torch
import numpy as np
import pandas as pd

from matplotlib import pyplot as plt

from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, ConcatDataset,random_split
from torchvision.datasets import ImageFolder
from torchvision import transforms, utils, models
from src.trainer import ModelTrainTest

torch.manual_seed(42)


batch_size = 32

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

test_dataset = ImageFolder(root='CIFAR-10-images/test',transform=normal_data_transform)

test_loader =  torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size,shuffle=False)

# Number of CIFAR-10 classes predicted by the downstream head.
num_classes = 10

# Downstream mode (siamese_deg=None): fc2 is replaced by a 512 -> num_classes
# head so it accepts the 512-dim patch features.
model_to_test = SelfSupervisedModel(siamese_deg=None)
model_to_test.fc2 = nn.Linear(512,num_classes)
checkpoint_path = 'Model/jigsaw_downstream_image_recognition.pt'

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location=device)
print(checkpoint.keys())

new_checkpoint = {}

# The saved keys carry a leading 'module.' that this model's parameter names
# do not have. Strip it only as a prefix, never from the middle of a key.
for key, value in checkpoint.items():
    new_key = key[len('module.'):] if key.startswith('module.') else key
    new_checkpoint[new_key] = value

model_to_test.load_state_dict(new_checkpoint)

# Move the loaded model to the evaluation device.
model_to_test = model_to_test.to(device)
print('Model loaded successfully')

odict_keys(['module.patch_model.conv1.weight', 'module.patch_model.conv1.bias', 'module.patch_model.conv2.weight', 'module.patch_model.conv2.bias', 'module.patch_model.conv3.weight', 'module.patch_model.conv3.bias', 'module.patch_model.fc1.weight', 'module.patch_model.fc1.bias', 'module.fc2.weight', 'module.fc2.bias', 'module.fc3.weight', 'module.fc3.bias'])
Model loaded successfully


In [7]:
def test_classification(network, test_data_loader):
    """Evaluate a classifier on ``test_data_loader`` and print a summary.

    The network is expected to return log-probabilities, since the loss is
    computed with ``F.nll_loss``. Batches are moved to the module-level
    ``device`` before the forward pass.

    Args:
        network: model to evaluate; it is switched to eval mode.
        test_data_loader: iterable of ``(data, target)`` batches exposing a
            ``dataset`` attribute whose length is the number of samples.

    Returns:
        tuple: ``(test_loss, test_acc)`` — the summed NLL divided by the
        dataset size, and the fraction of correct predictions.
    """
    network.eval()
    test_loss = 0
    correct = 0
    num_samples = len(test_data_loader.dataset)

    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for data, target in test_data_loader:
            data, target = data.to(device), target.to(device)
            output = network(data)
            # reduction='sum' replaces the deprecated size_average=False.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            correct += count_correct_preds(output, target)

    test_loss /= num_samples
    test_acc = correct / num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))

    return test_loss, test_acc

In [8]:
# Evaluate the downstream classifier; the cell displays (loss, accuracy).
test_classification(network=model_to_test, test_data_loader=test_loader)

  x = F.log_softmax(x)



Test set: Average loss: 1.0222, Accuracy: 6662/10000 (67%)



(1.0222248842716217, 0.6662)