In [3]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm 
from modules.datasets import MultiLabelDataset,ContrastiveDataset
from modules.siamese import SiameseNetwork
import torchdatasets as td
from pytorch_lightning.profiler import SimpleProfiler
import pytorch_lightning as pl

import warnings
warnings.filterwarnings("ignore", ".*does not have many workers.*") # we want to run Single-Core -> Ignore this warning

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [4]:
dataset = MultiLabelDataset("./Carparts",cache_in_ram=True,cache_path="./cache")
contrastiveDataset = ContrastiveDataset(dataset)
len(contrastiveDataset)

13809

In [5]:
train_size = int(0.8 * len(contrastiveDataset))
test_size = len(contrastiveDataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(contrastiveDataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128)

In [6]:
import torchvision.models as models
from torchsummary import summary

model = models.resnet50(pretrained=True)
# for param in model.parameters():
#     param.requires_grad = False
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(inplace=True),
    nn.Linear(512, 256),
    nn.ReLU(inplace=True),
    nn.Linear(256, 128),
)
summary(model.to(DEVICE), (3, 224, 224))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 112, 112]           9,408
       BatchNorm2d-2         [-1, 64, 112, 112]             128
              ReLU-3         [-1, 64, 112, 112]               0
         MaxPool2d-4           [-1, 64, 56, 56]               0
            Conv2d-5           [-1, 64, 56, 56]           4,096
       BatchNorm2d-6           [-1, 64, 56, 56]             128
              ReLU-7           [-1, 64, 56, 56]               0
            Conv2d-8           [-1, 64, 56, 56]          36,864
       BatchNorm2d-9           [-1, 64, 56, 56]             128
             ReLU-10           [-1, 64, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]          16,384
      BatchNorm2d-12          [-1, 256, 56, 56]             512
           Conv2d-13          [-1, 256, 56, 56]          16,384
      BatchNorm2d-14          [-1, 256,

In [7]:
siamese_model = SiameseNetwork(model)

trainer = pl.Trainer(accelerator="gpu", max_epochs=5, precision="bf16",benchmark=True,fast_dev_run=False)
trainer.fit(model=siamese_model, train_dataloaders=train_loader, val_dataloaders=test_loader)


Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type            | Params
----------------------------------------------
0 | model     | ResNet          | 24.7 M
1 | criterion | ContrastiveLoss | 0     
----------------------------------------------
24.7 M    Trainable params
0         Non-trainable params
24.7 M    Total params
98.885    Total estimated model params size (MB)


Epoch 4: 100%|██████████| 109/109 [11:55<00:00,  6.57s/it, loss=0.255, v_num=4]  


In [12]:
for subset in dataset.datasets:
    print(siamese_model(dataset.datasets[subset][0][0].reshape(1,3,224,224)))

tensor([[ 0.3259, -1.0109, -0.2502,  0.7671,  0.0248, -0.2950,  0.4633, -0.1364,
          0.2257,  0.0041, -0.3134, -0.0305,  0.4244,  0.0252,  0.0759,  0.0993,
          0.8282,  0.0491, -0.1028, -0.6206,  0.4152, -0.0645, -0.3571,  0.8471,
         -0.0549, -0.0816,  0.3380, -0.3665,  0.5714,  0.5733, -0.6821, -0.2083,
         -0.3417, -0.0619,  0.9279,  0.4821,  0.1775, -0.1955,  0.3373,  0.3271,
          0.1928, -0.1052,  0.1878,  0.1877,  0.1110, -0.2217,  0.3067, -0.7441,
         -1.2032,  1.1072, -0.1086, -0.5298, -0.2465, -0.1587, -0.0512, -0.7199,
         -0.0470,  0.0708, -0.6328, -0.0833,  0.6278,  0.3888, -0.1535, -0.2640,
         -0.4642, -0.4565, -0.1586, -0.2788,  0.6138,  0.0537, -0.6038,  0.4900,
         -0.6109, -0.5883,  0.1327,  0.8794, -0.2659, -0.3363,  0.4326, -0.0743,
          0.0701,  0.4669, -0.2329, -1.0083, -0.5487, -0.3766,  0.3819,  0.6669,
         -0.2136,  0.2426, -0.2461, -0.4134, -0.1024, -0.4160, -0.3675, -0.4726,
          0.6631,  0.3613, -