# Testing a Barlow Twins-trained ResNet50 with a classifier head on CIFAR10
For this test, we build our DNN using the ResNet50 published by the Barlow Twins group as the backbone and a fully connected layer as the classifier head.

In [None]:
%%shell

# Clone the TorchVision repository and pin it to the v0.8.2 tag so that the
# helper scripts under references/detection come from a known release.
git clone https://github.com/pytorch/vision.git
cd vision
git checkout v0.8.2

# Copy the detection reference helpers one directory up (next to the notebook)
# so they can be imported as plain modules, e.g. `from engine import evaluate`.
cp references/detection/utils.py ../
cp references/detection/transforms.py ../
cp references/detection/coco_eval.py ../
cp references/detection/engine.py ../
cp references/detection/coco_utils.py ../

Cloning into 'vision'...
remote: Enumerating objects: 79885, done.[K
remote: Counting objects: 100% (15019/15019), done.[K
remote: Compressing objects: 100% (1349/1349), done.[K
remote: Total 79885 (delta 13796), reused 14696 (delta 13590), pack-reused 64866
Receiving objects: 100% (79885/79885), 156.09 MiB | 41.11 MiB/s, done.
Resolving deltas: 100% (66867/66867), done.
Note: checking out 'v0.8.2'.

You are in 'detached HEAD' state. You can look around, make experimental
changes and commit them, and you can discard any commits you make in this
state without impacting any branches by performing another checkout.

If you want to create a new branch to retain commits you create, you may
do so (now or later) by using -b with the checkout command again. Example:

  git checkout -b <new-branch-name>

HEAD is now at 2f40a483d [v0.8.X] .circleci: Add Python 3.9 to CI (#3063)




In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from PIL import Image, ImageOps, ImageFilter
from torchvision import models, datasets, transforms
import cv2
from google.colab.patches import cv2_imshow
from engine import evaluate

In [None]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when used as a feature extractor.

    Args:
        model: Module whose ``parameters()`` are visited.
        feature_extracting: When truthy, ``requires_grad`` is set to ``False``
            on every parameter; when falsy the model is left untouched.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False

In [None]:
# Number of outputs of the new classifier head (the `classes` list used by the
# test helpers has ten entries).
num_classes = 10
# Load the ResNet50 exposed by the facebookresearch/barlowtwins hub entry point.
model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50', pretrained=True)
# Freeze everything that was just loaded; this must happen BEFORE the head is
# replaced so that the new layer below keeps requires_grad=True.
set_parameter_requires_grad(model, feature_extracting=True)
num_ftrs = model.fc.in_features
# Swap the final layer for a fresh linear head with `num_classes` outputs.
model.fc = nn.Linear(num_ftrs, num_classes)

Using cache found in /root/.cache/torch/hub/facebookresearch_barlowtwins_main


Collecting the parameters to update. Only the weights and biases of the fully connected layer should be listed.

In [None]:
# Gather the parameters the optimizer is allowed to change and print their
# names so the selection can be checked by eye.
params_to_update = []
for name, param in model.named_parameters():
  if not param.requires_grad:
    continue  # frozen parameter: skip it
  params_to_update.append(param)
  print("\t",name)

	 fc.weight
	 fc.bias


Setting the transforms for the train and test data and loading the CIFAR10 dataset.

In [None]:
# Training pipeline: resize to 224, centre-crop to 224x224, random horizontal
# flip as the only augmentation, then tensor conversion and normalisation.
# NOTE(review): the mean/std values are presumably the ImageNet channel
# statistics the backbone was trained with — confirm.
TrainTransforms = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

# Evaluation pipeline: identical to the training one but without the random flip.
TestTransforms = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

# CIFAR10 splits stored under data/cifar10 (downloaded if missing).
trainset = datasets.CIFAR10(root="data/cifar10", train=True, download=True, transform=TrainTransforms)
testset = datasets.CIFAR10(root="data/cifar10", train=False, download=True, transform=TestTransforms)

Files already downloaded and verified
Files already downloaded and verified


Defining the classes for the CIFAR10 dataset

In [None]:
# Human-readable class names; the test helpers index this list with the
# predicted class index.
# NOTE(review): assumed to follow the integer label order of
# torchvision's CIFAR10 — confirm.
classes = ['Airplane', 'Car', 'Bird','Cat','Deer','Dog','Frog','Horse','Ship','Truck']

Downloading sample images to test the classification.

In [None]:
# Download ten sample images. Each one is saved under a fixed local name
# (test1.jpg ... test10.jpg) regardless of its real format; later cells refer to
# the files by these names.
!wget https://cdn.hswstatic.com/gif/airplane-windows.jpg -O 'test1.jpg'
!wget https://www.irmi.com/images/default-source/article-images/aviation/boeing-737.jpg -O 'test2.jpg'
!wget https://images.frandroid.com/wp-content/uploads/2021/11/apple-car-concept.jpg -O 'test3.jpg'
!wget https://abcbirds.org/wp-content/uploads/2021/07/Blue-Jay-on-redbud-tree-by-Tom-Reichner_news.png -O 'test4.jpg'
!wget https://upload.wikimedia.org/wikipedia/commons/4/4d/Cat_November_2010-1a.jpg -O 'test5.jpg'
!wget https://iadsb.tmgrup.com.tr/7ddb86/0/0/0/0/1926/1086?u=https://idsb.tmgrup.com.tr/2018/05/22/horses-the-wings-of-mankind-1527015927739.jpg -O 'test6.jpg'
!wget https://carwow-uk-wp-3.imgix.net/Volvo-XC40-white-scaled.jpg -O 'test7.jpg'
!wget https://media.self.com/photos/6192b264fd75b7baf2aadbe1/4:3/w_2560%2Cc_limit/GettyImages-1219359156.jpg -O 'test8.jpg'
!wget https://upload.wikimedia.org/wikipedia/commons/d/d9/Motorboat_at_Kankaria_lake.JPG -O 'test9.jpg'
!wget https://cdn.britannica.com/84/206384-050-00698723/Javan-gliding-tree-frog.jpg -O 'test10.jpg'

--2022-01-20 16:53:10--  https://cdn.hswstatic.com/gif/airplane-windows.jpg
Resolving cdn.hswstatic.com (cdn.hswstatic.com)... 13.32.204.8, 13.32.204.52, 13.32.204.21, ...
Connecting to cdn.hswstatic.com (cdn.hswstatic.com)|13.32.204.8|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 107448 (105K) [image/jpeg]
Saving to: ‘test1.jpg’


2022-01-20 16:53:10 (33.2 MB/s) - ‘test1.jpg’ saved [107448/107448]

--2022-01-20 16:53:10--  https://www.irmi.com/images/default-source/article-images/aviation/boeing-737.jpg
Resolving www.irmi.com (www.irmi.com)... 104.18.162.71, 104.18.163.71, 2606:4700::6812:a347, ...
Connecting to www.irmi.com (www.irmi.com)|104.18.162.71|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 48515 (47K) [image/jpeg]
Saving to: ‘test2.jpg’


2022-01-20 16:53:11 (145 MB/s) - ‘test2.jpg’ saved [48515/48515]

--2022-01-20 16:53:11--  https://images.frandroid.com/wp-content/uploads/2021/11/apple-car-concept.jpg
Resolving images

In [None]:
# Local paths of the ten downloaded sample images, in download order.
path_test = [f'./test{idx}.jpg' for idx in range(1, 11)]
# Expected class name for each path above (same position in both lists).
label_test = [
    'Airplane', 'Airplane', 'Car', 'Bird', 'Cat',
    'Horse', 'Car', 'Deer', 'Ship', 'Frog',
]

Building a train function to train the classifier head.

In [None]:
def train(model, dataloader, dataloader_test, path_test, label_test, params, nepochs=1, lr=1e-3):
    """Train ``params`` of ``model`` with Adam and report accuracy after each epoch.

    Args:
        model: Network to train; it is moved to CUDA when available, else CPU.
        dataloader: Iterable of ``(images, labels)`` training batches.
        dataloader_test: Iterable of ``(images, labels)`` batches passed to ``test``.
        path_test: Image paths passed to ``TestModelCustomBatch``.
        label_test: Expected class names matching ``path_test``.
        params: Parameters handed to the optimizer.
        nepochs: Number of passes over ``dataloader``.
        lr: Adam learning rate.

    Returns:
        None. Progress and accuracies are printed.
    """
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Fix: ``lr`` used to be ignored, so Adam always ran with its default rate.
    optimizer = torch.optim.Adam(params, lr=lr, weight_decay=1e-5)
    criterion = nn.CrossEntropyLoss()
    model = model.to(device)
    criterion = criterion.to(device)

    running_loss = 0.
    running_samples = 0
    for epoch in range(nepochs):
        # NOTE(review): train() also switches any BatchNorm layers of a frozen
        # backbone to training mode, so their running statistics keep
        # changing — confirm this is intended.
        model.train()
        for it, (ims, labels) in enumerate(dataloader):
            ims = ims.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            out = model(ims)
            loss = criterion(out, labels)

            # Fix: accumulate a plain float (``.item()``) so the autograd graph
            # of earlier batches is not kept alive, and weight the batch-mean
            # loss by the batch size so the printed value is a true
            # per-sample average.
            batch_size = ims.shape[0]
            running_loss += loss.item() * batch_size
            running_samples += batch_size

            if it % 100 == 0:
                print(f'ep: {epoch}, it: {it}, loss : {running_loss/running_samples:.5f}')
                running_loss = 0.
                running_samples = 0

            loss.backward()
            optimizer.step()
        accuracy = test(model, dataloader_test, device)
        print(f'The accuracy on the test dataset for epoch {epoch} is: {accuracy}%')
        accuracy = TestModelCustomBatch(model, path_test, label_test)
        # Fix: this figure is for the custom images, not for the test dataset.
        print(f'The accuracy on the custom test images for epoch {epoch} is: {accuracy}%')

In [None]:
def test(model, dataloader_test, device):
  model.eval()
  correct_labels = 0
  total_labels = 0
  with torch.no_grad():
    for data in dataloader_test:
      ims, labels = data
      ims = ims.to(device)
      out = model(ims)
      for i, label in enumerate(labels):
        if label.item() == torch.argmax(out[i]).item():
          correct_labels += 1
      total_labels += labels.shape[0]
  accuracy = correct_labels/total_labels*100
  return accuracy

In [None]:
def TestModelCustomBatch(model,img_paths,expected):
  """Return the accuracy (in percent) of ``model`` on a list of image files.

  Args:
    model: Network to evaluate; it is switched to eval mode. The caller must
      already have placed it on the device selected below.
    img_paths: Paths of the images to classify.
    expected: Expected class names; ``expected[i]`` belongs to ``img_paths[i]``.

  Returns:
    Percentage of images whose predicted name (looked up in the module-level
    ``classes`` list) equals the expected one, or ``0.0`` for an empty input
    (previously this raised ZeroDivisionError).
  """
  img_paths = list(img_paths)
  if not img_paths:
    return 0.0

  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
  # Loop-invariant work hoisted out of the per-image loop: the preprocessing
  # pipeline is built once and the model is put in eval mode once.
  preprocess = transforms.Compose([
          transforms.Resize(256),
          transforms.CenterCrop(224),
          transforms.ToTensor(),
          transforms.Normalize(
          mean=[0.485, 0.456, 0.406],
          std=[0.229, 0.224, 0.225]
      )])
  model.eval()

  count = 0
  for it, img_path in enumerate(img_paths):
    img = Image.open(img_path).convert('RGB')
    # Add the batch dimension the model expects: (C, H, W) -> (1, C, H, W).
    batch_img = torch.unsqueeze(preprocess(img), 0).to(device)
    with torch.no_grad():
      out = model(batch_img)

    out = out.cpu().squeeze()
    if classes[torch.argmax(out).item()] == expected[it]:
      count += 1

  accuracy = count/len(img_paths)*100
  return accuracy


Training the classifier head of the Barlow Twins model.

In [None]:
input_shape = trainset[0][0].shape
# Fix: shuffle the training set so that every epoch sees the samples in a new
# order; the evaluation loader stays in its fixed order.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=25)
train(model, train_loader, test_loader,path_test,label_test,params_to_update, nepochs=10, lr=1e-3)

ep: 0, it: 0, loss : 0.02305
ep: 0, it: 100, loss : 0.01754
ep: 0, it: 200, loss : 0.01165
ep: 0, it: 300, loss : 0.00925
ep: 0, it: 400, loss : 0.00824
The accuracy on the test dataset for epoch 0 is: 80.05%
The accuracy on the test dataset for epoch 0 is: 50.0%
ep: 1, it: 0, loss : 0.00735
ep: 1, it: 100, loss : 0.00675
ep: 1, it: 200, loss : 0.00643
ep: 1, it: 300, loss : 0.00602
ep: 1, it: 400, loss : 0.00592
The accuracy on the test dataset for epoch 1 is: 83.28%
The accuracy on the test dataset for epoch 1 is: 60.0%
ep: 2, it: 0, loss : 0.00559
ep: 2, it: 100, loss : 0.00534
ep: 2, it: 200, loss : 0.00525
ep: 2, it: 300, loss : 0.00504
ep: 2, it: 400, loss : 0.00510
The accuracy on the test dataset for epoch 2 is: 84.55%
The accuracy on the test dataset for epoch 2 is: 60.0%
ep: 3, it: 0, loss : 0.00485
ep: 3, it: 100, loss : 0.00467
ep: 3, it: 200, loss : 0.00466
ep: 3, it: 300, loss : 0.00452
ep: 3, it: 400, loss : 0.00462
The accuracy on the test dataset for epoch 3 is: 85.47%

KeyboardInterrupt: ignored

Defining the function for testing.

In [None]:
def TestModel(model,img_path,expected):
  """Classify one image file and print the prediction beside the expected label.

  Args:
    model: Network to run; it is switched to eval mode.
    img_path: Path of the image to classify.
    expected: Expected class name, printed in the second column.
  """
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  # Resize/crop to 224x224, convert to a tensor and normalise the channels.
  to_model_input = transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
  ])
  rgb_image = Image.open(img_path).convert('RGB')
  # Add the leading batch dimension: (C, H, W) -> (1, C, H, W).
  single_batch = to_model_input(rgb_image).unsqueeze(0)

  model.eval()
  single_batch = single_batch.to(device)
  with torch.no_grad():
    out = model(single_batch)

  # Back to the CPU and drop the batch dimension before taking the arg-max.
  out = out.cpu().squeeze()
  print(f'{classes[torch.argmax(out).item()]:15} | {expected:15}')

Performing the tests.

In [None]:
# Table header, then one row per sample image: predicted class | expected class.
print(f'Barlow Twins    | Expected')
sample_expectations = [
    ('./test1.jpg', 'Airplane'),
    ('./test2.jpg', 'Airplane'),
    ('./test3.jpg', 'Car'),
    ('./test4.jpg', 'Bird'),
    ('./test5.jpg', 'Cat'),
    ('./test6.jpg', 'Horse'),
    ('./test7.jpg', 'Car'),
    ('./test8.jpg', 'Deer'),
    ('./test9.jpg', 'Ship'),
    ('./test10.jpg', 'Frog'),
]
for sample_path, sample_label in sample_expectations:
  TestModel(model, sample_path, sample_label)

Barlow Twins    | Expected
Airplane        | Airplane       
Airplane        | Airplane       
Airplane        | Car            
Bird            | Bird           
Bird            | Cat            
Airplane        | Horse          
Car             | Car            
Deer            | Deer           
Airplane        | Ship           
Airplane        | Frog           


# Testing a Barlow Twins-trained ResNet50 with a classifier head on TinyImageNet
Since the code for our model is already defined above, we begin by downloading and preprocessing the TinyImageNet dataset.

In [None]:
# Fetch the repository that contains the tiny-imagenet-200 folder and list the
# folder to check that the expected files (train/val/test, wnids.txt,
# words.txt) are present.
! git clone https://github.com/seshuad/IMagenet
! ls 'IMagenet/tiny-imagenet-200/'

Cloning into 'IMagenet'...
remote: Enumerating objects: 120594, done.[K
remote: Total 120594 (delta 0), reused 0 (delta 0), pack-reused 120594[K
Receiving objects: 100% (120594/120594), 212.68 MiB | 34.51 MiB/s, done.
Resolving deltas: 100% (1115/1115), done.
Checking out files: 100% (120206/120206), done.
test  train  val  wnids.txt  words.txt


In [None]:
import time
import scipy.ndimage as nd
import matplotlib.pyplot as plt

import numpy as np

# Dataset root; the loaders below build file names by plain string
# concatenation, so the trailing slash is required.
path = 'IMagenet/tiny-imagenet-200/'

def get_id_dictionary():
    """Map every identifier listed in ``wnids.txt`` to its zero-based line index.

    Returns:
        dict: ``{identifier: line_index}`` with the newline stripped from each
        identifier. The file is read from the module-level ``path``.
    """
    # Fix: the context manager closes the file; the bare open() never did.
    with open(path + 'wnids.txt', 'r') as wnids_file:
        return {line.replace('\n', ''): i for i, line in enumerate(wnids_file)}
  
def get_class_to_id_dict():
    """Map every class index to ``(identifier, description)``.

    Identifiers and their indices come from ``get_id_dictionary()``; the
    descriptions are the second tab-separated column of ``words.txt``.

    Returns:
        dict: ``{class_index: (identifier, description)}``.

    Raises:
        KeyError: If an identifier from ``wnids.txt`` is missing in ``words.txt``.
        ValueError: If a line of ``words.txt`` has no tab-separated second column.
    """
    id_dict = get_id_dictionary()
    all_classes = {}
    # Fix: close the file deterministically and strip the line ending, which
    # used to stay attached to the description (e.g. 'cat\n').
    with open(path + 'words.txt', 'r') as words_file:
        for line in words_file:
            n_id, word = line.rstrip('\n').split('\t')[:2]
            all_classes[n_id] = word
    return {value: (key, all_classes[key]) for key, value in id_dict.items()}

def _to_rgb(img):
    """Return ``img`` with three channels, replicating a 2-D grayscale image."""
    if np.ndim(img) == 2:
        return np.stack([img] * 3, axis=-1)
    return img


def _one_hot(index, num_classes):
    """Return a list of ``num_classes`` zeros with a one at ``index``."""
    label = [0] * num_classes
    label[index] = 1
    return label


def get_data(id_dict, num_classes=200, images_per_class=500):
    """Load the training and validation images with one-hot labels.

    Args:
        id_dict: ``{identifier: class_index}`` as built by ``get_id_dictionary``.
        num_classes: Length of every one-hot label vector.
        images_per_class: Number of training images read per class; files are
            expected at ``train/<id>/images/<id>_<i>.JPEG`` for
            ``i`` in ``range(images_per_class)``.

    Returns:
        tuple: ``(train_data, train_labels, test_data, test_labels)``; the data
        entries are lists of image arrays and the labels are lists of one-hot
        lists. The "test" pair is read from the ``val`` folder.
    """
    print('starting loading data')
    train_data, test_data = [], []
    train_labels, test_labels = [], []
    t = time.time()
    for key, value in id_dict.items():
        # Fix: grayscale files decode to 2-D arrays; convert them so that all
        # images share one shape and can later be stacked into a single array.
        train_data += [
            _to_rgb(plt.imread( path + 'train/{}/images/{}_{}.JPEG'.format(key, key, str(i))))
            for i in range(images_per_class)
        ]
        train_labels += [_one_hot(value, num_classes) for _ in range(images_per_class)]

    # Fix: the annotation file is now closed once it has been read.
    with open( path + 'val/val_annotations.txt') as annotations:
        for line in annotations:
            img_name, class_id = line.split('\t')[:2]
            test_data.append(_to_rgb(plt.imread( path + 'val/images/{}'.format(img_name))))
            test_labels.append(_one_hot(id_dict[class_id], num_classes))

    print('finished loading data, in {} seconds'.format(time.time() - t))
    return train_data, train_labels, test_data, test_labels
  
# Load the whole dataset into memory; all four results are Python lists.
train_data, train_labels, test_data, test_labels = get_data(get_id_dictionary())

starting loading data
finished loading data, in 29.7906277179718 seconds


In [None]:
# Fix: summarise the image shapes instead of printing one line per image.
# np.shape also works for 2-D (grayscale) images, on which the previous
# len(data[0][0]) raised TypeError.
shape_counts = {}
for data in train_data:
  shape = np.shape(data)
  shape_counts[shape] = shape_counts.get(shape, 0) + 1
for shape, count in shape_counts.items():
  print(f'The:{shape} -> {count} images')

The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643
The:64643


TypeError: ignored

In [None]:
# Fix: get_data returns Python lists, so `.shape` was only valid for the one
# variable that had been converted. Convert all four before reporting.
# Grayscale (2-D) images are expanded to three channels first; otherwise the
# images are ragged and cannot be stacked into one array.
train_data = np.stack([img if np.ndim(img) == 3 else np.stack([img] * 3, axis=-1) for img in train_data])
test_data = np.stack([img if np.ndim(img) == 3 else np.stack([img] * 3, axis=-1) for img in test_data])
train_labels = np.asarray(train_labels)
test_labels = np.asarray(test_labels)
print( "train data shape: ",  train_data.shape )
print( "train label shape: ", train_labels.shape )
print( "test data shape: ",   test_data.shape )
print( "test_labels.shape: ", test_labels.shape )

  """Entry point for launching an IPython kernel.


ValueError: ignored

# Training a Barlow Twins model

Initializing libraries.

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset

import numpy as np
import matplotlib.pyplot as plt

import argparse
from pathlib import Path
from PIL import Image, ImageOps, ImageFilter
from torch.autograd import Variable
import random
from torchvision import models, datasets, transforms
import cv2
from google.colab.patches import cv2_imshow

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

Defining the Barlow Twins class for the model.

In [None]:
class BarlowTwins(nn.Module):
  """ResNet50 backbone plus MLP projector trained with the Barlow Twins loss.

  ``forward`` takes two batches (two views of the same images) and returns the
  scalar loss directly.

  Args:
    lambd: Weight of the off-diagonal (redundancy-reduction) term.
  """

  def __init__(self, lambd=0.0051):
    super().__init__()
    self.lambd = lambd
    # ResNet50 backbone; its classification layer is replaced by an identity
    # so the backbone output feeds the projector (first projector layer
    # expects 2048 features).
    self.backbone = torchvision.models.resnet50(zero_init_residual=True)
    self.backbone.fc = nn.Identity()
    # Projector: Linear -> BatchNorm -> ReLU blocks followed by a final
    # bias-free Linear layer.
    sizes = [2048, 8192,8192,8192]
    layers = []
    for i in range(len(sizes) - 2):
      layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=False))
      layers.append(nn.BatchNorm1d(sizes[i + 1]))
      layers.append(nn.ReLU(inplace=True))
    layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
    self.projector = nn.Sequential(*layers)
    # Non-affine batch norm used to standardise the embeddings along the
    # batch dimension before they are correlated.
    self.bn = nn.BatchNorm1d(sizes[-1], affine=False)

  def forward(self, y1, y2):
    """Return the Barlow Twins loss for the two views ``y1`` and ``y2``.

    Both inputs must have the same batch size (first dimension).
    """
    z1 = self.projector(self.backbone(y1))
    z2 = self.projector(self.backbone(y2))

    # empirical cross-correlation matrix
    c = self.bn(z1).T @ self.bn(z2)

    # Fix: average over the batch. Without this division the entries scale
    # with the batch size, so pushing the diagonal towards 1 below does not
    # correspond to perfectly correlated features.
    c.div_(y1.shape[0])
    # NOTE(review): a multi-GPU run would additionally have to sum `c`
    # across processes here (the removed code hinted at an all_reduce).

    on_diag = torch.diagonal(c).add_(-1).pow_(2).sum()
    off_diag = off_diagonal(c).pow_(2).sum()
    loss = on_diag + self.lambd * off_diag
    return loss

In [None]:
def off_diagonal(x):
    """Return the off-diagonal elements of a square matrix as a 1-D tensor.

    Elements are returned in row-major order.

    Raises:
        AssertionError: If ``x`` is not square.
    """
    rows, cols = x.shape
    assert rows == cols
    # Dropping the last element and reshaping to (n - 1, n + 1) lines up every
    # remaining diagonal entry in column 0, which is then sliced away.
    shifted = x.flatten()[:-1].view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()

Mimicking the optimizer the Facebook Group used for the Barlow Twins training.

In [None]:
class LARS(optim.Optimizer):
  """Momentum optimizer with weight decay and a layer-wise update rescaling.

  Args:
    params: Parameters or parameter groups to optimise.
    lr: Learning rate.
    weight_decay: Coefficient of the parameter added to the gradient.
    momentum: Decay factor of the per-parameter buffer ``mu``.
    eta: Factor applied to ``||p|| / ||update||`` when rescaling.
    weight_decay_filter: When true, 1-D parameters skip the weight decay.
    lars_adaptation_filter: When true, 1-D parameters skip the rescaling.
  """

  def __init__(self, params, lr, weight_decay=0, momentum=0.9, eta=0.001, weight_decay_filter=False, lars_adaptation_filter=False):
    defaults = {
        'lr': lr,
        'weight_decay': weight_decay,
        'momentum': momentum,
        'eta': eta,
        'weight_decay_filter': weight_decay_filter,
        'lars_adaptation_filter': lars_adaptation_filter,
    }
    super().__init__(params, defaults)

  def exclude_bias_and_norm(self, p):
    """Return True for 1-D parameters."""
    return p.ndim == 1

  @torch.no_grad()
  def step(self):
    """Apply one update to every parameter that has a gradient."""
    for group in self.param_groups:
      for p in group['params']:
        if p.grad is None:
          continue
        update = p.grad
        one_dimensional = self.exclude_bias_and_norm(p)

        # Weight decay, unless the filter is on and the parameter is 1-D.
        if not (group['weight_decay_filter'] and one_dimensional):
          update = update.add(p, alpha=group['weight_decay'])

        # Rescale by eta * ||p|| / ||update||; the factor falls back to 1
        # whenever either norm is zero.
        if not (group['lars_adaptation_filter'] and one_dimensional):
          param_norm = torch.norm(p)
          update_norm = torch.norm(update)
          one = torch.ones_like(param_norm)
          trust_ratio = torch.where(param_norm > 0., torch.where(update_norm > 0, (group['eta'] * param_norm / update_norm), one), one)
          update = update.mul(trust_ratio)

        # Momentum buffer, created lazily on the first step.
        state = self.state[p]
        if 'mu' not in state:
          state['mu'] = torch.zeros_like(p)
        momentum_buffer = state['mu']
        momentum_buffer.mul_(group['momentum']).add_(update)
        p.add_(momentum_buffer, alpha=-group['lr'])

Defining the transforms.

In [None]:
class GaussianBlur:
    """Blur an image with probability ``p`` using a randomly drawn sigma."""

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        # Leave the image untouched unless the draw falls below p.
        if random.random() >= self.p:
            return img
        # Sigma is drawn uniformly from [0.1, 2.0).
        sigma = random.random() * 1.9 + 0.1
        return img.filter(ImageFilter.GaussianBlur(sigma))


class Solarization:
    """Solarize an image with probability ``p``."""

    def __init__(self, p):
        self.p = p

    def __call__(self, img):
        # Leave the image untouched unless the draw falls below p.
        if random.random() >= self.p:
            return img
        return ImageOps.solarize(img)


class Transform:
    """Produce two differently augmented views of one image.

    ``transform`` always blurs and never solarizes; ``transform_prime`` blurs
    with probability 0.1 and solarizes with probability 0.2. All other steps
    are identical in both pipelines.
    """

    def __init__(self):
        self.transform = self._build_pipeline(blur_p=1.0, solarize_p=0.0)
        self.transform_prime = self._build_pipeline(blur_p=0.1, solarize_p=0.2)

    @staticmethod
    def _build_pipeline(blur_p, solarize_p):
        """Build one augmentation pipeline with the given blur/solarize odds."""
        return transforms.Compose([
            # Fix: pass the InterpolationMode enum; the integer constant
            # Image.BICUBIC triggers torchvision's "interpolation should be
            # of type InterpolationMode instead of int" warning.
            transforms.RandomResizedCrop(
                224, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4,
                                        saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
            GaussianBlur(p=blur_p),
            Solarization(p=solarize_p),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __call__(self, x):
        """Return ``(y1, y2)``: ``x`` passed through each pipeline once."""
        y1 = self.transform(x)
        y2 = self.transform_prime(x)
        return y1, y2

Initializing the model, the parameters and the optimizer.

In [None]:
model = BarlowTwins()
#model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# Split the parameters into 1-D tensors (treated as biases / norm parameters)
# and everything else (treated as weights).
param_weights = []
param_biases = []

for param in model.parameters():
  if param.ndim == 1:
    param_biases.append(param)
  else:
    param_weights.append(param)

# Fix: with lr=0 and no learning-rate schedule in the training loop,
# optimizer.step() never changes the weights. Each group now gets its own
# non-zero rate, which overrides the lr=0 default passed below.
# NOTE(review): the base rates (0.2 for weights, 0.0048 for biases) and the
# batch_size/256 scaling are taken from the defaults of the reference
# Barlow Twins training script, which also warms up and decays the rate —
# confirm/tune for this dataset. 16 is the batch size of train_loader.
lr_scale = 16 / 256
parameters = [{'params': param_weights, 'lr': 0.2 * lr_scale}, {'params': param_biases, 'lr': 0.0048 * lr_scale}]
optimizer = LARS(parameters, lr=0, weight_decay=1e-6, weight_decay_filter=True, lars_adaptation_filter=True)

#DO SOMETHING LIKE THIS IF WE WANT TO WORK ON MULTIPLE GPUs.
#model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])



Optional: resume from a saved checkpoint.
Run the next cell only if you want to load the checkpoint; otherwise skip it.

In [None]:
# Fix: map_location='cpu' lets the checkpoint load on a runtime without a GPU
# (the model is still on the CPU at this point anyway).
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source. The path is an absolute Google Drive location and must be
# adapted per user.
ckpt = torch.load('/content/drive/MyDrive/MIR/1stSemester/DeepLearning/BarlowTwins/checkpoint.pth', map_location='cpu')

# Changing nomenclature to match with model: drop the 'module.' prefix in
# front of the backbone / projector / bn keys.
for key in list(ckpt['model'].keys()):
  ckpt['model'][key.replace('module.backbone', 'backbone').replace('module.projector','projector').replace('module.bn','bn')] = ckpt['model'].pop(key)

# NOTE(review): start_epoch is read here but the training loop below does not
# use it — confirm whether resuming the epoch counter is intended.
start_epoch = ckpt['epoch']
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])

### To read the dictionary for the weights in the model ###
#for param_tensor in ckpt['model']:
#  print(param_tensor, "\t", ckpt['model'][param_tensor].size())

Preparing the dataset:

In [None]:
# Download the dataset archive by its Google Drive file id, then unpack it
# quietly (-q), overwriting any existing files (-o).
!gdown --id 1LOM-2A1BSLaFjCY2EEK3DA2Lo37rNw-7
!unzip -oq underwater_imagenet.zip

Downloading...
From: https://drive.google.com/uc?id=1LOM-2A1BSLaFjCY2EEK3DA2Lo37rNw-7
To: /content/underwater_imagenet.zip
100% 449M/449M [00:02<00:00, 168MB/s]


In [None]:
from PIL import Image as im
from tqdm import tqdm

In [None]:
class CustomDataset(Dataset):
    """Dataset of in-memory image arrays that yields two views per item.

    Args:
        imgs: Indexable collection of image arrays.
        transform: Callable that maps a PIL image to a ``(y1, y2)`` pair.
            When ``None`` the untransformed image is returned for both views.
        target_transform: Stored for interface compatibility; not used,
            because the dataset has no targets.
    """

    def __init__(self, imgs, transform=None, target_transform=None):
      self.imgs = imgs
      self.transform = transform
      self.target_transform = target_transform

    def __len__(self):
      return len(self.imgs)

    def __getitem__(self, idx):
      sample = im.fromarray(self.imgs[idx])
      # Fix: without a transform, y1/y2 were never assigned and the method
      # raised UnboundLocalError.
      if self.transform is None:
        return sample, sample
      y1, y2 = self.transform(sample)
      return y1, y2

In [None]:
import glob
# Collect every .jpg under trainA and decode each file into a NumPy array.
filelist = glob.glob('/content/underwater_imagenet/trainA/*.jpg')
# NOTE(review): stacking into one array only yields a regular N-D array if
# all images share the same shape — confirm for this dataset.
train_imgs = np.array([np.array(Image.open(fname)) for fname in filelist])

In [None]:
transform = Transform()
train_set = CustomDataset(train_imgs,transform=transform)
# Fix: shuffle so that every pass over the data builds different batches;
# the loss correlates features across the samples of a batch.
train_loader = torch.utils.data.DataLoader(train_set,batch_size=16,shuffle=True)

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [None]:
episodes = 5
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Fix: move the model to the device once. The loop used to move it to the GPU
# and back to the CPU on every iteration, so each optimizer step paid for two
# full model transfers and ran on the CPU.
model = model.to(device)
# Momentum buffers restored from a checkpoint may live on another device than
# the parameters; bring them along so optimizer.step() can run on `device`.
for param_state in optimizer.state.values():
  for state_key, state_value in param_state.items():
    if torch.is_tensor(state_value):
      param_state[state_key] = state_value.to(device)

for episode in tqdm(range(episodes)):
  for it, (y1, y2) in enumerate(train_loader):
    y1 = y1.to(device)
    y2 = y2.to(device)
    optimizer.zero_grad()
    # The model's forward pass returns the loss directly.
    loss = model(y1,y2)
    loss.backward()
    optimizer.step()

 20%|██        | 1/5 [19:26<1:17:45, 1166.33s/it]


KeyboardInterrupt: ignored