In [1]:
pip install pytorch-lightning lightly

Note: you may need to restart the kernel to use updated packages.


In [1]:
import sys
sys.path.append('../')
import torch
import PIL
import pytorch_lightning as pl
import torchvision
from les.les import les_dist_comp, les_desc_comp, _build_graph
from les_pytorch.les import LES
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision.models import alexnet
from lightly.models.modules.heads import SimCLRProjectionHead
from matplotlib import pyplot as plt
import lightly

In [14]:
# Dataset selection: 'ImageNet' and 'CREMI' are special-cased; any other value is
# treated as a DomainNet domain folder name.
domain = 'Painting'
DATASETS_ROOT = '../../../../talmodata-nfs/datasets'
# Datasets stored outside the DomainNet/<domain> layout (keys are lower-case).
SPECIAL_DATA_PATHS = {
    'imagenet': f'{DATASETS_ROOT}/TinyImageNet/train',
    'cremi': f'{DATASETS_ROOT}/cremi/jpegs',
}
domain_key = domain.lower()
data_path = SPECIAL_DATA_PATHS.get(domain_key, f'{DATASETS_ROOT}/DomainNet/{domain_key}')
print(f'Data path is {data_path}')

example_dataset = lightly.data.LightlyDataset(data_path)
# Two-view SimCLR augmentation at 64x64 with blur and colour jitter disabled.
test_collate_fn = lightly.data.SimCLRCollateFunction(input_size=64, gaussian_blur=0.0, cj_prob=0.0)
# Not shuffled, so each batch holds the same files on every run.
dataloader_simclr = DataLoader(
    example_dataset,
    batch_size=512,
    shuffle=False,
    collate_fn=test_collate_fn,
    drop_last=True,
    num_workers=8,
    pin_memory=False,
)
print(f'Dataset size: {len(example_dataset)}')

Data path is ../../../../talmodata-nfs/datasets/DomainNet/painting
Dataset size: 75759


In [3]:
# Grab the first batch for inspection: `batch` holds the two augmented views,
# followed by the labels and source filenames.
batch, labels, files = next(iter(dataloader_simclr))
aug1 = batch[0]
aug2 = batch[1]

# Set to True to plot a few original images next to their two augmented views.
SHOW_AUGMENTATIONS = False
if SHOW_AUGMENTATIONS:
    n_show = 5
    fig, ax = plt.subplots(n_show, 3, figsize=(9, 3 * n_show), facecolor='w')
    for i, file in enumerate(files[:n_show]):
        ax[i, 0].imshow(PIL.Image.open(f'{data_path}/{file}'))
        # Tensors are channel-first; imshow expects channel-last.
        ax[i, 1].imshow(aug1[i].permute(1, 2, 0))
        ax[i, 2].imshow(aug2[i].permute(1, 2, 0))
    ax[0, 0].set_title('Original')
    ax[0, 1].set_title('View 1')
    ax[0, 2].set_title('View 2')

In [15]:
class LESclrModel(pl.LightningModule):
    """SimCLR-style self-supervised model trained with the LES loss.

    Each training batch holds two augmented views of the same images (built by
    lightly's ``SimCLRCollateFunction``). Both views go through the shared
    backbone and projection head, and their embeddings are compared by ``LES()``.

    Args:
        dataset_path: Image directory handed to ``lightly.data.LightlyDataset``.
        backbone: torchvision model whose final child module is dropped to form
            the feature extractor. When ``None`` (default), a fresh
            ``alexnet(pretrained=False)`` is built for this instance.
        transform: Per-image transform applied by ``LightlyDataset``.
        batch_size: Training batch size.
        temp: Stored as ``self.temp``; not used by this class.
        learning_rate: SGD learning rate.
        momentum: SGD momentum.
        embedding_dim: Output dimension of the projection head.
        input_size: ``input_size`` passed to ``SimCLRCollateFunction``.
        weight_decay: SGD weight decay.
    """

    def __init__(self,
                 dataset_path,
                 backbone = None,
                 transform = torchvision.transforms.Resize((64,64)),
                 batch_size = 512,
                 temp = 0.1,
                 learning_rate = 1e-2,
                 momentum = 0.9,
                 embedding_dim = 128,
                 input_size = 64,
                 weight_decay = 1e-6):
        super().__init__()
        # Build the default backbone per instance. A default argument value is
        # evaluated once, at class-definition time, so every model relying on it
        # would otherwise share (and jointly train) the same AlexNet weights.
        # A separate local name leaves `backbone` untouched for save_hyperparameters().
        backbone_model = backbone if backbone is not None else torchvision.models.alexnet(pretrained=False)
        # Flattened size of AlexNet's conv stack + AdaptiveAvgPool2d((6, 6)):
        # 256 channels * 6 * 6. Must be updated if a different backbone is passed.
        hidden_dim = 9216
        # Drop the final child (the classifier head) and keep the feature extractor.
        self.backbone = torch.nn.Sequential(*list(backbone_model.children())[:-1])
        self.dataset_path = dataset_path
        self.transform = transform
        self.batch_size = batch_size
        self.temp = temp
        self.lr = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.projection_head = SimCLRProjectionHead(hidden_dim, hidden_dim, embedding_dim)
        self.criterion = LES()
        self.collate_fn = lightly.data.SimCLRCollateFunction(input_size=input_size,gaussian_blur=0.0,cj_prob=0.0)
        self.save_hyperparameters()

    def forward(self, x):
        """Embed a batch of images: backbone features, flattened, then projected."""
        h = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(h)
        return z

    def training_step(self, batch, batch_idx):
        """Compute and log the LES loss between the embeddings of the two views."""
        # batch is ((view0, view1), labels, filenames), as produced by the collate function.
        (x0, x1), _, _ = batch
        z0 = self.forward(x0)
        z1 = self.forward(x1)
        loss = self.criterion(z0, z1)
        self.log("train_loss_les", loss)
        return loss

    def configure_optimizers(self):
        """Plain SGD over all parameters; no learning-rate scheduler."""
        # If a scheduler that needs a monitored metric is added, monitor
        # "train_loss_les" (the key logged in training_step).
        return torch.optim.SGD(params=self.parameters(),
                               lr=self.lr,
                               momentum=self.momentum,
                               weight_decay=self.weight_decay)

    def train_dataloader(self):
        """Shuffled loader yielding two augmented views per image via ``self.collate_fn``."""
        dataset = lightly.data.LightlyDataset(input_dir=self.dataset_path,transform=self.transform)
        return DataLoader(
                        dataset,
                        batch_size=self.batch_size,
                        shuffle=True,
                        collate_fn=self.collate_fn,
                        drop_last=True,
                        num_workers=8,
                        pin_memory=False
                    )

In [16]:
# Instantiate the LES model on the selected dataset. With transform=None, images
# reach the collate function without any per-image transform.
# NOTE(review): learning_rate=4.8 is very high for plain SGD; confirm it is intended.
LESclr_model = LESclrModel(
    dataset_path=data_path,
    transform=None,
    batch_size=1024,
    temp=0.1,
    learning_rate=4.8,
    momentum=0.9,
    embedding_dim=128,
    input_size=64,
    weight_decay=1e-6,
)
LESclr_model

LESclrModel(
  (backbone): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace=True)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace=True)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace=True)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace=True)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): AdaptiveAvgPool2d(output_size=(6, 6))
  )
  (projection_head): SimCLRProjectionHead(
    (layers): Sequential(
      (0):

In [None]:
# Sanity check: compare the NumPy LES implementation (les_desc_comp + les_dist_comp)
# with the PyTorch `LES` module on the embeddings of both augmented views.
les_criterion = LES()  # built once and reused for every batch
was_training = LESclr_model.training
# Run in eval mode without autograd, so this check cannot alter any train-mode
# layer state (e.g. normalisation statistics, if the model has any) before training.
LESclr_model.eval()
try:
    with torch.no_grad():
        for (aug1, aug2), labels, files in dataloader_simclr:
            embed1 = LESclr_model(aug1).to(torch.float64)
            embed2 = LESclr_model(aug2).to(torch.float64)
            les_embed1 = les_desc_comp(embed1.numpy())
            les_embed2 = les_desc_comp(embed2.numpy())
            print(les_dist_comp(les_embed1, les_embed2))
            print(les_criterion(embed1, embed2))
finally:
    # Restore whatever mode the model was in before the check.
    LESclr_model.train(was_training)

In [13]:
# Number of samples per view in the most recently loaded batch.
aug1.shape[0]

512

In [None]:
# Smoke test of the PyTorch LES loss on two random (2, 128) float32 inputs.
# A seeded generator makes the displayed value reproducible across kernel restarts.
les_gen = torch.Generator().manual_seed(0)
x1 = torch.randn(size=(2, 128), dtype=torch.float32, generator=les_gen)
x2 = torch.randn(size=(2, 128), dtype=torch.float32, generator=les_gen)
LES()(x1, x2)

In [11]:
def describe_device(device: torch.device) -> str:
    """Return a human-readable message describing the compute device in use."""
    # Compare `device.type`: a torch.device never compares equal to a plain
    # string, so `device == "cpu"` would always be False.
    if device.type == "cpu":
        return 'GPU is not available. Using CPU'
    return 'Using GPU'


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(describe_device(device))

Using GPU


In [12]:
# Train on a single GPU when one is available, otherwise fall back to CPU
# (gpus=1 on a CPU-only machine makes Trainer construction fail).
# strategy="dp" (DataParallel) only adds overhead on a single device, so it is not used.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0,
                     max_epochs=10,
                     log_every_n_steps=1)
trainer.fit(LESclr_model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
