In [2]:
from functools import partial
from typing import Sequence, Tuple, Union

import lightning as L
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as VisionF
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import Tensor
from torch.utils.data import DataLoader
from torchmetrics.functional import accuracy
from torchvision.datasets import CIFAR10
from torchvision.models.resnet import resnet34
from torchvision.utils import make_grid
from torchvision import models
from torchsummary import summary
from pytorch_lightning.loggers import CSVLogger



logger = CSVLogger("logs_out", name="encoder_logs")


batch_size = 64
num_workers = 8
max_epochs = 200
z_dim = 1024

In [3]:
class BarlowTwinsTransform:
    def __init__(self, train=True, input_height=224, gaussian_blur=True, jitter_strength=1.0, normalize=None):
        self.input_height = input_height
        self.gaussian_blur = gaussian_blur
        self.jitter_strength = jitter_strength
        self.normalize = normalize
        self.train = train

        color_jitter = transforms.ColorJitter(
            0.8 * self.jitter_strength,
            0.8 * self.jitter_strength,
            0.8 * self.jitter_strength,
            0.2 * self.jitter_strength,
        )

        color_transform = [transforms.RandomApply([color_jitter], p=0.8), transforms.RandomGrayscale(p=0.2)]

        if self.gaussian_blur:
            kernel_size = int(0.1 * self.input_height)
            if kernel_size % 2 == 0:
                kernel_size += 1

            color_transform.append(transforms.RandomApply([transforms.GaussianBlur(kernel_size=kernel_size)], p=0.5))

        self.color_transform = transforms.Compose(color_transform)

        if normalize is None:
            self.final_transform = transforms.ToTensor()
        else:
            self.final_transform = transforms.Compose([transforms.ToTensor(), normalize])

        self.transform = transforms.Compose(
            [
                transforms.RandomResizedCrop(self.input_height),
                transforms.RandomHorizontalFlip(p=0.5),
                self.color_transform,
                self.final_transform,
            ]
        )

        self.finetune_transform = None
        if self.train:
            self.finetune_transform = transforms.Compose(
                [
                    transforms.RandomCrop(32, padding=4, padding_mode="reflect"),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                ]
            )
        else:
            self.finetune_transform = transforms.ToTensor()

    def __call__(self, sample):
        return self.transform(sample), self.transform(sample), self.finetune_transform(sample)

In [4]:
def cifar10_normalization():
    normalize = transforms.Normalize(
        mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
    )
    return normalize


train_transform = BarlowTwinsTransform(
    train=True, input_height=32, gaussian_blur=False, jitter_strength=0.5, normalize=cifar10_normalization()
)
train_dataset = CIFAR10(root=".", train=True, download=True, transform=train_transform)

val_transform = BarlowTwinsTransform(
    train=False, input_height=32, gaussian_blur=False, jitter_strength=0.5, normalize=cifar10_normalization()
)
val_dataset = CIFAR10(root=".", train=False, download=True, transform=train_transform)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True,pin_memory = True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=True, pin_memory = True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [01:17<00:00, 2188961.02it/s]


Extracting ./cifar-10-python.tar.gz to .
Files already downloaded and verified


In [5]:
encoder = resnet34()

# for CIFAR10, replace the first 7x7 conv with smaller 3x3 conv and remove the first maxpool
encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
encoder.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)

# replace classification fc layer of Resnet to obtain representations from the backbone
encoder.fc = nn.Identity()

In [6]:
class ProjectionHead(nn.Module):
    def __init__(self, input_dim=1024, hidden_dim=1024, output_dim=1024):
        super().__init__()

        self.projection_head = nn.Sequential(
            nn.Linear(input_dim, hidden_dim, bias=True),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim, bias=False),
        )

    def forward(self, x):
        return self.projection_head(x)

In [7]:
def fn(warmup_steps, step):
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    else:
        return 1.0


def linear_warmup_decay(warmup_steps):
    return partial(fn, warmup_steps)

In [9]:
class BarlowTwins(LightningModule):
    def __init__(
        self,
        encoder,
        encoder_out_dim,
        num_training_samples,
        batch_size,
        lambda_coeff=5e-3,
        z_dim=128,
        learning_rate=1e-4,
        warmup_epochs=10,
        max_epochs=200,
    ):
        super().__init__()

        self.encoder = encoder
        self.projection_head = ProjectionHead(input_dim=encoder_out_dim, hidden_dim=encoder_out_dim, output_dim=z_dim)
        self.loss_fn = BarlowTwinsLoss(batch_size=batch_size, lambda_coeff=lambda_coeff, z_dim=z_dim)

        self.learning_rate = learning_rate
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs

        self.train_iters_per_epoch = num_training_samples // batch_size

    def forward(self, x):
        return self.encoder(x)

    def shared_step(self, batch):
        (x1, x2, _), _ = batch

        z1 = self.projection_head(self.encoder(x1))
        z2 = self.projection_head(self.encoder(x2))

        loss, on_diag, off_diag= self.loss_fn(z1, z2)
       
        self.log("on_diag", on_diag.sum(), on_step=True, on_epoch=True, prog_bar= True, logger=logger)
        self.log("off_diag", off_diag.sum(), on_step=True, on_epoch=True, prog_bar= True, logger=logger)

        return loss
    
    def training_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.shared_step(batch)
        self.log("val_loss", loss, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)

        warmup_steps = self.train_iters_per_epoch * self.warmup_epochs

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                linear_warmup_decay(warmup_steps),
            ),
            "interval": "step",
            "frequency": 1,
        }

        return [optimizer], [scheduler]

In [10]:
class OnlineFineTuner(Callback):
    def __init__(
        self,
        encoder_output_dim: int,
        num_classes: int,
    ) -> None:
        super().__init__()

        self.optimizer: torch.optim.Optimizer

        self.encoder_output_dim = encoder_output_dim
        self.num_classes = num_classes

    def on_fit_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        # add linear_eval layer and optimizer
        pl_module.online_finetuner = nn.Linear(self.encoder_output_dim, self.num_classes).to(pl_module.device)
        self.optimizer = torch.optim.Adam(pl_module.online_finetuner.parameters(), lr=1e-4)

    def extract_online_finetuning_view(
        self, batch: Sequence, device: Union[str, torch.device]
    ) -> Tuple[Tensor, Tensor]:
        (_, _, finetune_view), y = batch
        finetune_view = finetune_view.to(device)
        y = y.to(device)

        return finetune_view, y

    def on_train_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
    ) -> None:
        x, y = self.extract_online_finetuning_view(batch, pl_module.device)

        with torch.no_grad():
            feats = pl_module(x)

        feats = feats.detach()
        preds = pl_module.online_finetuner(feats)
        loss = F.cross_entropy(preds, y)

        loss.backward()
        self.optimizer.step()
        self.optimizer.zero_grad()

        acc = accuracy(F.softmax(preds, dim=1), y, task="multiclass", num_classes=10)
        pl_module.log("online_train_acc", acc, on_step=True, on_epoch=False)
        pl_module.log("online_train_loss", loss, on_step=True, on_epoch=False)

    def on_validation_batch_end(
        self,
        trainer: L.Trainer,
        pl_module: L.LightningModule,
        outputs: Sequence,
        batch: Sequence,
        batch_idx: int,
    ) -> None:
        x, y = self.extract_online_finetuning_view(batch, pl_module.device)

        with torch.no_grad():
            feats = pl_module(x)

        feats = feats.detach()
        preds = pl_module.online_finetuner(feats)
        loss = F.cross_entropy(preds, y)

        acc = accuracy(F.softmax(preds, dim=1), y, task="multiclass", num_classes=10)
        pl_module.log("online_val_acc", acc, on_step=False, on_epoch=True, sync_dist=True)
        pl_module.log("online_val_loss", loss, on_step=False, on_epoch=True, sync_dist=True)

In [13]:
class BarlowTwinsLoss(nn.Module):
    def __init__(self, batch_size, lambda_coeff=5e-3, z_dim=128):
        super().__init__()

        self.z_dim = z_dim
        self.batch_size = batch_size
        self.lambda_coeff = lambda_coeff

    def off_diagonal_ele(self, x):
        # taken from: https://github.com/facebookresearch/barlowtwins/blob/main/main.py
        # return a flattened view of the off-diagonal elements of a square matrix
        n, m = x.shape
        assert n == m
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()

    def forward(self, z1, z2):
        # N x D, where N is the batch size and D is output dim of projection head
        z1_norm = (z1 - torch.mean(z1, dim=0)) / torch.std(z1, dim=0)
        z2_norm = (z2 - torch.mean(z2, dim=0)) / torch.std(z2, dim=0)

        cross_corr = torch.mm(z1_norm.T, z2_norm) / self.batch_size

        on_diag = torch.diagonal(cross_corr).add_(-1).pow_(2).sum()
        off_diag = self.off_diagonal_ele(cross_corr).pow_(2).sum()

        return on_diag + self.lambda_coeff * off_diag, on_diag, off_diag

In [16]:
encoder_out_dim = 512

model = BarlowTwins(
    encoder=encoder,
    encoder_out_dim=encoder_out_dim,
    num_training_samples=len(train_dataset),
    batch_size=batch_size,
    z_dim=z_dim,
)

online_finetuner = OnlineFineTuner(encoder_output_dim=encoder_out_dim, num_classes=10)
checkpoint_callback = ModelCheckpoint(every_n_epochs=100, save_top_k=-1, save_last=True)

trainer = Trainer(
    max_epochs=max_epochs,
    accelerator="auto",
    devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs
    callbacks=[#online_finetuner,
        checkpoint_callback],
    logger=logger
)

# uncomment this to train the model
# this is done for the tutorial so that the notebook compiles
trainer.fit(model, train_loader,)#val_loader)

ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None) will duplicate the last checkpoint saved.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type            | Params
----------------------------------------------------
0 | encoder         | ResNet          | 21.3 M
1 | projection_head | ProjectionHead  | 787 K 
2 | loss_fn         | BarlowTwinsLoss | 0     
----------------------------------------------------
22.1 M    Trainable params
0         Non-trainable params
22.1 M    Total params
88.260    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


In [22]:
class classifire(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=512, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=10),
        )
    
    def forward(self, x:torch.Tensor):
        x = self.block(x)
        return x


In [23]:
classifire_model = classifire()

In [24]:
ai_optimizer = torch.optim.Adam(classifire_model.parameters(), lr = 1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(ai_optimizer, 200)
goai_loss_fn = nn.CrossEntropyLoss()
# ai_optimizer = torch.optim.SGD(classifire_model.parameters(), lr=1e-3,
#                                 momentum=0.9, weight_decay=5e-4)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(ai_optimizer, 200)

In [25]:
from tqdm import tqdm

In [26]:

running_loss = 0.0

running_corrects = 0
epochs =  100

for epoch in tqdm(range(epochs)):
    running_loss = 0.0
    running_corrects = 0
    for batch in train_loader:
        (img1, img2, _), label = batch
        img1 = img1.to('cuda')
        model = model.to('cuda')
        label = label.to('cuda')
        model.eval()
        classifire_model.to('cuda')
        classifire_model.train()

        encoder_out = model.forward(img1)
        input_to_goai = torch.unsqueeze(encoder_out, dim=1)             ### these unsuqeese is done to make the data follow NCHW pattern which is used inside the pytorch as a data format
        input_to_goai = torch.unsqueeze(input_to_goai, dim=1)
        #y_pred = GOAI_model(encoder_out)
        
       # GOAI_model.to('cuda')
        y_pred = classifire_model(input_to_goai)
        _, preds = torch.max(y_pred, 1)
        
        loss = goai_loss_fn(y_pred, label)
        ai_optimizer.zero_grad()

        loss.backward()

        ai_optimizer.step()
        running_loss += loss.item() * batch_size
        running_corrects += torch.sum(preds == label.data)
    scheduler.step()
    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = running_corrects.double() / len(train_loader.dataset)

    print('Epoch: {} Loss: {:.4f} Acc: {:.4f}'.format(epoch,  epoch_loss, epoch_acc))




Epoch: 0 Loss: 0.6874 Acc: 0.7728




Epoch: 1 Loss: 0.6142 Acc: 0.7869




Epoch: 2 Loss: 0.5975 Acc: 0.7927




Epoch: 3 Loss: 0.5900 Acc: 0.7939




Epoch: 4 Loss: 0.5756 Acc: 0.7970




Epoch: 5 Loss: 0.5747 Acc: 0.7978




Epoch: 6 Loss: 0.5660 Acc: 0.8009




Epoch: 7 Loss: 0.5612 Acc: 0.8005




Epoch: 8 Loss: 0.5562 Acc: 0.8047




Epoch: 9 Loss: 0.5460 Acc: 0.8074




Epoch: 10 Loss: 0.5505 Acc: 0.8041




Epoch: 11 Loss: 0.5454 Acc: 0.8070




Epoch: 12 Loss: 0.5412 Acc: 0.8080




Epoch: 13 Loss: 0.5387 Acc: 0.8085




Epoch: 14 Loss: 0.5370 Acc: 0.8085




Epoch: 15 Loss: 0.5295 Acc: 0.8133




Epoch: 16 Loss: 0.5285 Acc: 0.8120




Epoch: 17 Loss: 0.5282 Acc: 0.8128




Epoch: 18 Loss: 0.5257 Acc: 0.8123




Epoch: 19 Loss: 0.5175 Acc: 0.8159




Epoch: 20 Loss: 0.5176 Acc: 0.8151




Epoch: 21 Loss: 0.5188 Acc: 0.8163




Epoch: 22 Loss: 0.5162 Acc: 0.8155




Epoch: 23 Loss: 0.5124 Acc: 0.8184




Epoch: 24 Loss: 0.5112 Acc: 0.8174




Epoch: 25 Loss: 0.5093 Acc: 0.8191




Epoch: 26 Loss: 0.5014 Acc: 0.8205




Epoch: 27 Loss: 0.5028 Acc: 0.8207




Epoch: 28 Loss: 0.4966 Acc: 0.8232




Epoch: 29 Loss: 0.4982 Acc: 0.8212




Epoch: 30 Loss: 0.4976 Acc: 0.8224




Epoch: 31 Loss: 0.4980 Acc: 0.8223




Epoch: 32 Loss: 0.4926 Acc: 0.8242




Epoch: 33 Loss: 0.4880 Acc: 0.8255




Epoch: 34 Loss: 0.4900 Acc: 0.8264




Epoch: 35 Loss: 0.4834 Acc: 0.8272




Epoch: 36 Loss: 0.4872 Acc: 0.8273




Epoch: 37 Loss: 0.4815 Acc: 0.8267




Epoch: 38 Loss: 0.4883 Acc: 0.8266




Epoch: 39 Loss: 0.4757 Acc: 0.8310




Epoch: 40 Loss: 0.4754 Acc: 0.8311




Epoch: 41 Loss: 0.4756 Acc: 0.8309




Epoch: 42 Loss: 0.4712 Acc: 0.8332




Epoch: 43 Loss: 0.4730 Acc: 0.8306




Epoch: 44 Loss: 0.4727 Acc: 0.8294




Epoch: 45 Loss: 0.4710 Acc: 0.8321




Epoch: 46 Loss: 0.4668 Acc: 0.8335




Epoch: 47 Loss: 0.4618 Acc: 0.8352




Epoch: 48 Loss: 0.4625 Acc: 0.8350




Epoch: 49 Loss: 0.4597 Acc: 0.8354




Epoch: 50 Loss: 0.4610 Acc: 0.8351




Epoch: 51 Loss: 0.4600 Acc: 0.8359




Epoch: 52 Loss: 0.4560 Acc: 0.8364




Epoch: 53 Loss: 0.4527 Acc: 0.8373




Epoch: 54 Loss: 0.4548 Acc: 0.8380




Epoch: 55 Loss: 0.4487 Acc: 0.8390




Epoch: 56 Loss: 0.4489 Acc: 0.8406




Epoch: 57 Loss: 0.4507 Acc: 0.8387




Epoch: 58 Loss: 0.4482 Acc: 0.8384




Epoch: 59 Loss: 0.4452 Acc: 0.8420




Epoch: 60 Loss: 0.4462 Acc: 0.8400




Epoch: 61 Loss: 0.4394 Acc: 0.8427




Epoch: 62 Loss: 0.4399 Acc: 0.8418




Epoch: 63 Loss: 0.4364 Acc: 0.8448




Epoch: 64 Loss: 0.4357 Acc: 0.8453




Epoch: 65 Loss: 0.4349 Acc: 0.8445




Epoch: 66 Loss: 0.4327 Acc: 0.8468




Epoch: 67 Loss: 0.4337 Acc: 0.8447




Epoch: 68 Loss: 0.4307 Acc: 0.8468




Epoch: 69 Loss: 0.4318 Acc: 0.8448




Epoch: 70 Loss: 0.4351 Acc: 0.8447




Epoch: 71 Loss: 0.4338 Acc: 0.8450




Epoch: 72 Loss: 0.4221 Acc: 0.8501




Epoch: 73 Loss: 0.4175 Acc: 0.8508




Epoch: 74 Loss: 0.4246 Acc: 0.8478




Epoch: 75 Loss: 0.4212 Acc: 0.8511




Epoch: 76 Loss: 0.4231 Acc: 0.8498




Epoch: 77 Loss: 0.4188 Acc: 0.8501




Epoch: 78 Loss: 0.4131 Acc: 0.8526




Epoch: 79 Loss: 0.4193 Acc: 0.8518




Epoch: 80 Loss: 0.4126 Acc: 0.8525




Epoch: 81 Loss: 0.4152 Acc: 0.8519




Epoch: 82 Loss: 0.4152 Acc: 0.8537




Epoch: 83 Loss: 0.4056 Acc: 0.8555




Epoch: 84 Loss: 0.4077 Acc: 0.8552




Epoch: 85 Loss: 0.4138 Acc: 0.8539




Epoch: 86 Loss: 0.4094 Acc: 0.8547




Epoch: 87 Loss: 0.4073 Acc: 0.8563




Epoch: 88 Loss: 0.4031 Acc: 0.8556




Epoch: 89 Loss: 0.4093 Acc: 0.8541




Epoch: 90 Loss: 0.4070 Acc: 0.8554




Epoch: 91 Loss: 0.4009 Acc: 0.8583




Epoch: 92 Loss: 0.4024 Acc: 0.8571




Epoch: 93 Loss: 0.4013 Acc: 0.8577




Epoch: 94 Loss: 0.3966 Acc: 0.8586




Epoch: 95 Loss: 0.3994 Acc: 0.8597




Epoch: 96 Loss: 0.3997 Acc: 0.8576




Epoch: 97 Loss: 0.4007 Acc: 0.8572




Epoch: 98 Loss: 0.3958 Acc: 0.8602


100%|██████████| 100/100 [46:20<00:00, 27.81s/it]

Epoch: 99 Loss: 0.3893 Acc: 0.8620





In [28]:
correct = 0
total = 0
with torch.inference_mode():
    for batch in val_loader:
         (img1, img2, _), label = batch
         model.eval()
         classifire_model.eval()
         model.to('cpu')
         classifire_model.to('cpu')
    
         encoder_out = model.forward(img1)
         input_to_goai = torch.unsqueeze(encoder_out, dim=1)             ### these unsuqeese is done to make the data follow NCHW pattern which is used inside the pytorch as a data format
         input_to_goai = torch.unsqueeze(input_to_goai, dim=1)
         y_pred = classifire_model(input_to_goai)
         _, preds = torch.max(y_pred.data, 1)
         total += label.size(0)
         correct += (preds == label).sum().item()
# epoch_acc = running_corrects.double() / len(val_loader.dataset)
# print('Acc: {:.4f}'.format(epoch_acc))
print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))     

Accuracy of the network on the test images: 79 %


In [37]:
class classifire(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=512, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=1024),
            nn.ReLU(),
            nn.Linear(in_features=1024, out_features=1024),
            nn.LeakyReLU(),
            nn.Linear(in_features=1024, out_features=10)
        )
    
    def forward(self, x:torch.Tensor):
        x = self.block(x)
        return x


In [43]:
##end to end ml
## use resent-9 as per planned before
GOAI_model = models.resnet18(pretrained=False).to('cuda')
GOAI_model.conv1 = nn.Conv2d(1, 64, kernel_size=(3,3), stride=(1, 1), padding=(1, 1), bias=False)
num_features = GOAI_model.fc.in_features
GOAI_model.fc = nn.Linear(num_features, 10)

ai_optimizer = torch.optim.Adam(GOAI_model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(ai_optimizer, 100)
goai_loss_fn = nn.CrossEntropyLoss()



In [44]:
# classifire_model = classifire()
# ai_optimizer = torch.optim.Adam(classifire_model.parameters(), lr = 1e-4)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(ai_optimizer, 100)
# goai_loss_fn = nn.CrossEntropyLoss()
# # ai_optimizer = torch.optim.SGD(classifire_model.parameters(), lr=1e-3,
# #                                 momentum=0.9, weight_decay=5e-4)
# # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(ai_optimizer, 200)

In [47]:

running_loss = 0.0

running_corrects = 0
epochs =  100

for epoch in tqdm(range(epochs)):
    running_loss = 0.0
    running_corrects = 0
    for batch in train_loader:
        (img1, img2, _), label = batch
        img1 = img1.to('cuda')
        model = model.to('cuda')
        label = label.to('cuda')
        model.eval()
        classifire_model.to('cuda')
        classifire_model.train()

        encoder_out = model.forward(img1)
        input_to_goai = torch.unsqueeze(encoder_out, dim=1)             ### these unsuqeese is done to make the data follow NCHW pattern which is used inside the pytorch as a data format
        input_to_goai = torch.unsqueeze(input_to_goai, dim=1)
        GOAI_model.to('cuda')
        y_pred = GOAI_model(input_to_goai)
        
        #GOAI_model.to('cuda')
        #y_pred = classifire_model(input_to_goai)
        _, preds = torch.max(y_pred, 1)
        
        loss = goai_loss_fn(y_pred, label)
        ai_optimizer.zero_grad()

        loss.backward()

        ai_optimizer.step()
        running_loss += loss.item() * batch_size
        running_corrects += torch.sum(preds == label.data)
    scheduler.step()
    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = running_corrects.double() / len(train_loader.dataset)

    print('Epoch: {} Loss: {:.4f} Acc: {:.4f}'.format(epoch,  epoch_loss, epoch_acc))




Epoch: 0 Loss: 0.8484 Acc: 0.7285




Epoch: 1 Loss: 0.6832 Acc: 0.7724




Epoch: 2 Loss: 0.6537 Acc: 0.7797




Epoch: 3 Loss: 0.6425 Acc: 0.7807




Epoch: 4 Loss: 0.6239 Acc: 0.7882




Epoch: 5 Loss: 0.6117 Acc: 0.7899




Epoch: 6 Loss: 0.6011 Acc: 0.7931




Epoch: 7 Loss: 0.5982 Acc: 0.7931




Epoch: 8 Loss: 0.5975 Acc: 0.7950




Epoch: 9 Loss: 0.5881 Acc: 0.7962




Epoch: 10 Loss: 0.5738 Acc: 0.8028




Epoch: 11 Loss: 0.5696 Acc: 0.8027




Epoch: 12 Loss: 0.5673 Acc: 0.8023




Epoch: 13 Loss: 0.5513 Acc: 0.8102




Epoch: 14 Loss: 0.5498 Acc: 0.8083




Epoch: 15 Loss: 0.5427 Acc: 0.8130




Epoch: 16 Loss: 0.5421 Acc: 0.8117




Epoch: 17 Loss: 0.5350 Acc: 0.8142


In [None]:
correct = 0
total = 0
with torch.inference_mode():
    for batch in val_loader:
         (img1, img2, _), label = batch
         model.eval()
         classifire_model.eval()
         model.to('cpu')
         classifire_model.to('cpu')
    
         encoder_out = model.forward(img1)
         input_to_goai = torch.unsqueeze(encoder_out, dim=1)             ### these unsuqeese is done to make the data follow NCHW pattern which is used inside the pytorch as a data format
         input_to_goai = torch.unsqueeze(input_to_goai, dim=1)
         y_pred = classifire_model(input_to_goai)
         _, preds = torch.max(y_pred.data, 1)
         total += label.size(0)
         correct += (preds == label).sum().item()
# epoch_acc = running_corrects.double() / len(val_loader.dataset)
# print('Acc: {:.4f}'.format(epoch_acc))
print('Accuracy of the network on the test images: %d %%' % (100 * correct / total))     

Accuracy of the network on the test images: 79 %
