In [1]:
import argparse
import builtins
import math
import os
import random
import shutil
import time
import warnings
import numpy as np


import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from torchvision.models import vgg16
from HSI_class import HSI
import createSample as CS
import augmentation as aug

best_acc1 = 0

In [2]:
model_names = sorted(name for name in models.__dict__
    if name.islower() and not name.startswith("__")
    and callable(models.__dict__[name]))

In [3]:
from torchvision.models import vgg16

class VGG16_HSI(nn.Module):
    def __init__(self, num_classes=2):
        super(VGG16_HSI, self).__init__()

         # Custom Convolutional Layer: Process 9x9x224 input
        self.pre_conv = nn.Sequential(
            nn.Conv2d(in_channels=224, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.AdaptiveAvgPool2d((1, 1))  # Reduce to (256, 1, 1)
        )

        # Fully Connected Layer to reshape to (64, 56, 56)
        self.fc = nn.Linear(256 * 1 * 1, 64 * 56 * 56)

        # Load VGG-16 Model
        self.encoder = vgg16(pretrained=False)

        # Remove first VGG-16 conv layer
        self.encoder.features = nn.Sequential(*list(self.encoder.features.children())[1:])

        # Modify classifier to output 2 classes
        self.encoder.classifier[6] = nn.Linear(4096, 2)

    def forward(self, x):
        # print(f'before {x.shape}')
        x = self.pre_conv(x)  # Process hyperspectral input
        x = x.view(x.size(0), -1)  # Flatten

        # print(f'after preconv {x.shape}')
        x = self.fc(x)  # Fully connected layer
        # print(f'after fc {x.shape}')
        # Reshape to (batch_size, 64, 56, 56) before passing to VGG
        x = x.view(x.size(0), 64, 56, 56)
        # print(f'after reshape, before vgg second layer {x.shape}')

        x = self.encoder.features(x)  # Pass to VGG-16
        x = self.encoder.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.encoder.classifier(x)  # Final classification layer

        return x

In [4]:
gpu = 0

print("Use GPU: {} for training".format(gpu))

print("=> creating model")

model = VGG16_HSI()

print(model)




Use GPU: 0 for training
=> creating model




VGG16_HSI(
  (pre_conv): Sequential(
    (0): Conv2d(224, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): AdaptiveAvgPool2d(output_size=(1, 1))
  )
  (fc): Linear(in_features=256, out_features=200704, bias=True)
  (encoder): VGG(
    (features): Sequential(
      (0): ReLU(inplace=True)
      (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): ReLU(inplace=True)
      (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=Tr

In [5]:
# Freeze all layers except the last fully connected layer
for param in model.pre_conv.parameters():
    param.requires_grad = False  # Freeze convolutional layers
for param in model.fc.parameters():
    param.requires_grad = False  # Freeze convolutional layers
for param in model.encoder.features.parameters():
    param.requires_grad = False  # Freeze convolutional layers

for param in model.encoder.classifier[:-1].parameters():
    param.requires_grad = False  # Freeze all but the last FC layer

# Initialize the last FC layer
# Initialize the last FC layer
torch.nn.init.normal_(model.encoder.classifier[6].weight, mean=0.0, std=0.01)
torch.nn.init.zeros_(model.encoder.classifier[6].bias)

# Check which layers are trainable
for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

pre_conv.0.weight: requires_grad=False
pre_conv.0.bias: requires_grad=False
pre_conv.2.weight: requires_grad=False
pre_conv.2.bias: requires_grad=False
pre_conv.3.weight: requires_grad=False
pre_conv.3.bias: requires_grad=False
pre_conv.5.weight: requires_grad=False
pre_conv.5.bias: requires_grad=False
fc.weight: requires_grad=False
fc.bias: requires_grad=False
encoder.features.1.weight: requires_grad=False
encoder.features.1.bias: requires_grad=False
encoder.features.4.weight: requires_grad=False
encoder.features.4.bias: requires_grad=False
encoder.features.6.weight: requires_grad=False
encoder.features.6.bias: requires_grad=False
encoder.features.9.weight: requires_grad=False
encoder.features.9.bias: requires_grad=False
encoder.features.11.weight: requires_grad=False
encoder.features.11.bias: requires_grad=False
encoder.features.13.weight: requires_grad=False
encoder.features.13.bias: requires_grad=False
encoder.features.16.weight: requires_grad=False
encoder.features.16.bias: requir

In [6]:
pretrained = r'C:\Users\Asus TUF\Documents\code\TA\simsiam\simsiam\models\checkpoint_0009.pth.tar'

if pretrained:
    if os.path.isfile(pretrained):
        print("=> loading checkpoint '{}'".format(pretrained))
        checkpoint = torch.load(pretrained, map_location="cpu")

        # rename moco pre-trained keys
        state_dict = checkpoint['state_dict']

        print("Model state_dict keys:", model.state_dict().keys())  # Debugging
        print("Checkpoint state_dict keys:", state_dict.keys())  
        # for k in list(state_dict.keys()):
        #     print(f"Processing key: {k}")  # Debugging
        #     if k.startswith('module.encoder') and not k.startswith('module.encoder.fc'):
        #         state_dict[k[len("module.encoder."):]] = state_dict[k]
        #     del state_dict[k]

        # Remove the final classification layer from state_dict
        state_dict = {k: v for k, v in state_dict.items() if not k.startswith("encoder.classifier.6")}

        # Load the modified state_dict (ignoring the missing classification layer)
        msg = model.load_state_dict(state_dict, strict=False)

        # Check missing keys
        print("Missing keys:", msg.missing_keys)

        start_epoch = 0
        msg = model.load_state_dict(state_dict, strict=False)
  
        assert set(msg.missing_keys) == {"encoder.classifier.6.weight", "encoder.classifier.6.bias"}

        print("=> loaded pre-trained model '{}'".format(pretrained))
    else:
        print("=> no checkpoint found at '{}'".format(pretrained))



=> loading checkpoint 'C:\Users\Asus TUF\Documents\code\TA\simsiam\simsiam\models\checkpoint_0009.pth.tar'


  checkpoint = torch.load(pretrained, map_location="cpu")


Model state_dict keys: odict_keys(['pre_conv.0.weight', 'pre_conv.0.bias', 'pre_conv.2.weight', 'pre_conv.2.bias', 'pre_conv.2.running_mean', 'pre_conv.2.running_var', 'pre_conv.2.num_batches_tracked', 'pre_conv.3.weight', 'pre_conv.3.bias', 'pre_conv.5.weight', 'pre_conv.5.bias', 'pre_conv.5.running_mean', 'pre_conv.5.running_var', 'pre_conv.5.num_batches_tracked', 'fc.weight', 'fc.bias', 'encoder.features.1.weight', 'encoder.features.1.bias', 'encoder.features.4.weight', 'encoder.features.4.bias', 'encoder.features.6.weight', 'encoder.features.6.bias', 'encoder.features.9.weight', 'encoder.features.9.bias', 'encoder.features.11.weight', 'encoder.features.11.bias', 'encoder.features.13.weight', 'encoder.features.13.bias', 'encoder.features.16.weight', 'encoder.features.16.bias', 'encoder.features.18.weight', 'encoder.features.18.bias', 'encoder.features.20.weight', 'encoder.features.20.bias', 'encoder.features.23.weight', 'encoder.features.23.bias', 'encoder.features.25.weight', 'enco

In [7]:
print(model)

VGG16_HSI(
  (pre_conv): Sequential(
    (0): Conv2d(224, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): AdaptiveAvgPool2d(output_size=(1, 1))
  )
  (fc): Linear(in_features=256, out_features=200704, bias=True)
  (encoder): VGG(
    (features): Sequential(
      (0): ReLU(inplace=True)
      (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): ReLU(inplace=True)
      (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=Tr

In [8]:
dataset_path = r"C:\Users\Asus TUF\Documents\code\TA\Hyperspectral oil spill detection datasets"

dataset = []

i = 0
for filename in os.listdir(dataset_path):
    if i > 1:
        break
    file_path = os.path.join(dataset_path, filename)
    if os.path.isfile(file_path):  # Check if it's a file
        print(f"Processing file: {file_path}")
        hsi = HSI(file_path)
        dataset.append(hsi)
    i += 1
hsi_ = dataset[0]
patch_size = 9
sample_per_class = 5
selected_patch_0, selected_patch_1, random_indices_0, random_indices_1 = CS.createSample(hsi_, patch_size, sample_per_class)

i =0
half_patch = patch_size // 2
print(hsi_.img[random_indices_0[i][0]][random_indices_0[i][1]])
print(selected_patch_0[i][half_patch][half_patch])

print(hsi_.img[random_indices_1[i][0]][random_indices_1[i][1]])
print(selected_patch_1[i][half_patch][half_patch])

Processing file: C:\Users\Asus TUF\Documents\code\TA\Hyperspectral oil spill detection datasets\GM01.mat
Processing file: C:\Users\Asus TUF\Documents\code\TA\Hyperspectral oil spill detection datasets\GM02.mat
hsi shape
(1243, 684, 224)
creating 5 Randomly chosen 0 indices:
creating 5 Randomly chosen 1 indices:
[-237 -623  348  403  448  598  646  652  634  598  558  503  465  436
  406  384  360  332  307  278  252  229  205  181  158  145  138  123
  119  108  105  101   96   94   83   72   53   50   40    5   24   38
   47    9   24   32   26   26   21    5    5   16   28   24   25   18
   19    1   -5   -7   -6  -55 -152  -71   -8    4   14   19   24   27
   25   21   25   21   24   19   14    9   -1  -31 -236 -277 -118 -126
  -38   -5    0    0    4   -4    6    6   12   16   16   21   16   33
   25   11   17   17    0  -22  -30  -34    0    0    0    0    0    0
    0    0 -227  -39  -72  -50  -53  -23   -8    0   12   15   14   15
   23   23   22   25   21   23   21   29   28   

In [9]:
indices = random_indices_0 +  random_indices_1

# Concatenating along axis 0
x_train = np.concatenate((selected_patch_0, selected_patch_1), )

y_train = np.array([])

gt = hsi_.gt
for indice in indices:
    # print(gt[indice[0]][indice[1]])
    y_train = np.append(y_train, gt[indice[0]][indice[1]])

count = np.count_nonzero(y_train == 0)  # Count elements equal to 0
print(f'number of element equal 0 {count}')

count = np.count_nonzero(y_train == 1)  # Count elements equal to 1
print(f'number of element equal 1 {count}')



# Print shape to verify
print(f"x_train shape: {x_train.shape}")  # Expected output: (10, 9, 9, 224)
print(f"y_train shape: {y_train.shape}") 


number of element equal 0 5
number of element equal 1 5
x_train shape: (10, 9, 9, 224)
y_train shape: (10,)


In [10]:
n_category = 2
band_size = 224
num_per_category = 50

data_augment1, label_augment1 = aug.Augment_data(x_train, y_train, n_category, patch_size, band_size, num_per_category)

data_augment2, label_augment2 = aug.Augment_data2(x_train, y_train, n_category, patch_size, band_size, 200)

print(f"hasil augmentasi 1 shape: {data_augment1.shape}")
print(f"label augmentai 1 shape: {label_augment1.shape}")

print(f"hasil augmentasi 2 shape: {data_augment2.shape}")
print(f"label augmentasi 2 shape: {label_augment2.shape}")

print(label_augment1)
print(label_augment2)

# Count occurrences of each unique element
counts1 = np.bincount(label_augment1)

# Print results
for i, count in enumerate(counts1):
    print(f"Element {i} occurs {count} times.")

counts2 = np.bincount(label_augment2)

# Print results
for i, count in enumerate(counts2):
    print(f"Element {i} occurs {count} times.")

print(label_augment1[3])

j:  100
hasil augmentasi 1 shape: (100, 9, 9, 224)
label augmentai 1 shape: (100,)
hasil augmentasi 2 shape: (400, 9, 9, 224)
label augmentasi 2 shape: (400,)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 

In [11]:
data_augment = np.concatenate((data_augment1, data_augment2))
label_augment = np.concatenate((label_augment1, label_augment2))

print(f"hasil augmentasi gabungan untuk training: {data_augment.shape}")
print(f"label augmentasi gabungan: {label_augment.shape}")

# print(label_augment)

# Count occurrences of each unique element
counts = np.bincount(label_augment)

# Print results
for i, count in enumerate(counts):
    print(f"Element {i} occurs {count} times.")

hasil augmentasi gabungan untuk training: (500, 9, 9, 224)
label augmentasi gabungan: (500,)
Element 0 occurs 250 times.
Element 1 occurs 250 times.


In [12]:
test = data_augment[0]
test = torch.tensor(test)
test = test.to(torch.float32)
test = test.unsqueeze(0)

input = test
input = input.permute(0, 3, 1, 2)

test2 = data_augment[1]
test2 = torch.tensor(test2)
test2 = test2.to(torch.float32)
test2 = test2.unsqueeze(0)

input2 = test2
input2 = input2.permute(0, 3, 1, 2)

print(f"input shape: {input.shape}")
print(f"input2 shape: {input2.shape}")

# Pass the input through the model
model.eval()
output = model(input)

print(output)

input shape: torch.Size([1, 224, 9, 9])
input2 shape: torch.Size([1, 224, 9, 9])
tensor([[0.6780, 0.3978]], grad_fn=<AddmmBackward0>)


In [13]:
lr = 0.1
batch_size = 50
momentum = 0.9
weight_decay = 0

init_lr = lr * batch_size / 256

torch.cuda.set_device(gpu)
model = model.cuda(gpu)

# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss().cuda(gpu)

# optimize only the linear classifier
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias

optimizer = torch.optim.SGD(parameters, init_lr,
                            momentum=momentum,
                            weight_decay=weight_decay)

cudnn.benchmark = True

In [14]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

augmentation = [
    transforms.RandomHorizontalFlip(),  # Flip along width
    transforms.RandomVerticalFlip(),    # Flip along height
    transforms.RandomRotation(20),      # Rotate image slightly
    transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize hyperspectral data
]

transform = transforms.Compose(augmentation)

print(data_augment.shape)

(500, 9, 9, 224)


In [15]:
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        """
        Args:
            images (Tensor or list of Tensors): Preloaded images of shape (N, 9, 9, 224)
            transform (callable, optional): Optional transform to be applied on an image.
        """
        self.images = images  # Assuming it's a list or tensor
        self.transform = transform
        self.label = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        label = self.label[idx]
        
        if self.transform:
            img1 = self.transform(img)  # First augmentation
        
            return img1, label  # Return both augmented versions
        
        return img, label  # If no transform is provided, return the original image twice


# Example usage
preloaded_images = data_augment  # Example tensor with 100 images
X = torch.tensor(preloaded_images)
X= X.to(torch.float32)
X = X.permute(0, 3, 1, 2)
print(f"X_train shape: {X.shape}")

y = torch.tensor(label_augment)

# Define transformations if needed
transform = transforms.Compose(augmentation)

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)
print(f"Train shape: {X_train.shape}, Validation shape: {X_val.shape}")

train_dataset = CustomDataset(X_train, y_train, transform=transform)
val_dataset = CustomDataset(X_val, y_val, transform=transform)

train_sampler = None

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True, drop_last=True)

# 7. Check Output
print("tes")
batch1 = next(iter(train_loader))
print("tes2")
print(batch1[1].size())
print(f"Train loader size: {len(train_loader)}, Validation loader size: {len(val_loader)}")

X_train shape: torch.Size([500, 224, 9, 9])
Train shape: torch.Size([400, 224, 9, 9]), Validation shape: torch.Size([100, 224, 9, 9])
tes
tes2
torch.Size([50])
Train loader size: 8, Validation loader size: 2


In [27]:
def train(train_loader, model, criterion, optimizer, epoch):
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))

    """
    Switch to eval mode:
    Under the protocol of linear classification on frozen features/models,
    it is not legitimate to change any part of the pre-trained model.
    BatchNorm in train mode may revise running mean/std (even if it receives
    no gradient), which are part of the model parameters too.
    """
    model.eval()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        gpu = 0
        images = images.cuda(gpu, non_blocking=True)
        target = target.cuda(gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, = accuracy(output, target, topk=(1,))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))


        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        print_freq = 10
        if i % print_freq == 0:
            progress.display(i)


def validate(val_loader, model, criterion):
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            gpu = 0
            images = images.cuda(gpu, non_blocking=True)
            target = target.cuda(gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, = accuracy(output, target, topk=(1,))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            # top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            print_freq = 10
            if i % print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    return top1.avg


def save_checkpoint(state, is_best, filename='models/checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


def sanity_check(state_dict, pretrained_weights):
    """
    Linear classifier should not change any weights other than the linear layer.
    This sanity check asserts nothing wrong happens (e.g., BN stats updated).
    """
    print("=> loading '{}' for sanity check".format(pretrained_weights))
    checkpoint = torch.load(pretrained_weights, map_location="cpu")
    state_dict_pre = checkpoint['state_dict']

    for k in list(state_dict.keys()):
        # Ignore fc layer
        if 'fc.weight' in k or 'fc.bias' in k:
            continue

        # Adjust key mapping to match checkpoint format
        k_pre = k.replace('module.encoder.', '')  # Remove unnecessary prefix

        # Skip missing keys
        if k_pre not in state_dict_pre:
            print(f"Warning: {k_pre} not found in pretrained model. Skipping...")
            continue

        # Check if tensor shapes match before comparing values
        if state_dict[k].shape != state_dict_pre[k_pre].shape:
            print(f"Warning: Shape mismatch for {k}: {state_dict[k].shape} vs {state_dict_pre[k_pre].shape}. Skipping...")
            continue

        # Assert that the weights remain unchanged
        assert ((state_dict[k].cpu() == state_dict_pre[k_pre]).all()), \
            '{} is changed in linear classifier training.'.format(k)

    print("=> sanity check passed.")



class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def adjust_learning_rate(optimizer, init_lr, epoch, epochs):
    """Decay the learning rate based on schedule"""
    cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


In [28]:
start_epoch = 0
epochs = 5

for epoch in range(start_epoch, epochs):
   
    adjust_learning_rate(optimizer, init_lr, epoch, epochs)

    # train for one epoch
    train(train_loader, model, criterion, optimizer, epoch)

    # evaluate on validation set
    acc1 = validate(val_loader, model, criterion)

    # remember best acc@1 and save checkpoint
    is_best = acc1 > best_acc1
    best_acc1 = max(acc1, best_acc1)

    arch = 'vgg16'
    save_checkpoint({
        'epoch': epoch + 1,
        'arch': arch,
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        'optimizer' : optimizer.state_dict(),
    }, is_best)
    pretrained = r'C:\Users\Asus TUF\Documents\code\TA\simsiam\simsiam\models\checkpoint_0009.pth.tar'
    if epoch == start_epoch:
        sanity_check(model.state_dict(), pretrained)

Epoch: [0][0/8]	Time  0.306 ( 0.306)	Data  0.129 ( 0.129)	Loss 4.9816e-01 (4.9816e-01)	Acc@1  96.00 ( 96.00)	Acc@5   0.00 (  0.00)
Test: [0/2]	Time  0.126 ( 0.126)	Loss 2.1182e-02 (2.1182e-02)	Acc@1 100.00 (100.00)	Acc@5   0.00 (  0.00)
 * Acc@1 98.000 Acc@5 0.000
=> loading 'C:\Users\Asus TUF\Documents\code\TA\simsiam\simsiam\models\checkpoint_0009.pth.tar' for sanity check


  checkpoint = torch.load(pretrained_weights, map_location="cpu")


=> sanity check passed.
Epoch: [1][0/8]	Time  0.191 ( 0.191)	Data  0.127 ( 0.127)	Loss 1.5987e-01 (1.5987e-01)	Acc@1  94.00 ( 94.00)	Acc@5   0.00 (  0.00)
Test: [0/2]	Time  0.135 ( 0.135)	Loss 1.1848e-01 (1.1848e-01)	Acc@1  94.00 ( 94.00)	Acc@5   0.00 (  0.00)
 * Acc@1 97.000 Acc@5 0.000
Epoch: [2][0/8]	Time  0.094 ( 0.094)	Data  0.049 ( 0.049)	Loss 5.6304e-02 (5.6304e-02)	Acc@1  98.00 ( 98.00)	Acc@5   0.00 (  0.00)
Test: [0/2]	Time  0.124 ( 0.124)	Loss 4.0512e-02 (4.0512e-02)	Acc@1  98.00 ( 98.00)	Acc@5   0.00 (  0.00)
 * Acc@1 97.000 Acc@5 0.000
Epoch: [3][0/8]	Time  0.122 ( 0.122)	Data  0.055 ( 0.055)	Loss 1.8085e-01 (1.8085e-01)	Acc@1  94.00 ( 94.00)	Acc@5   0.00 (  0.00)
Test: [0/2]	Time  0.126 ( 0.126)	Loss 4.2719e-02 (4.2719e-02)	Acc@1  98.00 ( 98.00)	Acc@5   0.00 (  0.00)
 * Acc@1 99.000 Acc@5 0.000
Epoch: [4][0/8]	Time  0.150 ( 0.150)	Data  0.063 ( 0.063)	Loss 5.2699e-02 (5.2699e-02)	Acc@1  98.00 ( 98.00)	Acc@5   0.00 (  0.00)
Test: [0/2]	Time  0.121 ( 0.121)	Loss 1.3909e-02 (