Clone the repository

In [None]:
!git clone -b dev https://github.com/VimsLab/scr.git

Cloning into 'scr'...
remote: Enumerating objects: 233, done.
remote: Counting objects: 100% (233/233), done.
remote: Compressing objects: 100% (207/207), done.
remote: Total 233 (delta 29), reused 219 (delta 18), pack-reused 0
Receiving objects: 100% (233/233), 20.93 MiB | 14.83 MiB/s, done.
Resolving deltas: 100% (29/29), done.


In [None]:
# Work from the repo's `dent` directory: the imports cell does `from pretrain import *`
# (presumably scr/dent/pretrain.py), and the data/checkpoint paths used later
# ('./', '../sample_pkl/') are relative to this directory.
# NOTE(review): this is a relative path — re-running the cell once already inside
# scr/dent will not find the directory; restart the kernel (or cd back) first.
%cd scr/dent/

/content/scr/dent


Imports

In [None]:
import os
import sys
import math
import torch
import pickle
import random
import numpy as np
import torch.nn as nn
import concurrent.futures
import torch.nn.functional as F
import torch.distributed as dist
import torch.multiprocessing as mp
import torchvision.transforms as transforms

from PIL import Image
from glob import glob
from tqdm import tqdm
from pathlib import Path
from torch.nn.parallel import DataParallel
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torchvision.models import resnet50, ResNet50_Weights
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

from pretrain import *

In [None]:
def pretrainer(rank, world_size, root, dataroot, phases=('sample', 'sample'), resume=False, save=False,
               num_epochs=5, batch_size=64, lr=0.001, seed=None):
    """Pretrain the Siamese encoder on one (DDP) process.

    Intended to be called once per process with ``rank`` in ``range(world_size)``;
    the run cell below calls it directly with ``rank=0, world_size=1``.

    Args:
        rank: Process rank; also used as the CUDA device index for the model.
        world_size: Total number of processes in the group.
        root: Prefix for checkpoint paths. It is string-concatenated, so include the
            trailing separator (e.g. ``'./'``).
        dataroot: Location passed to ``get_dataset`` for the pickled pair lists.
        phases: ``(train_phase, val_phase)`` names passed to ``get_dataset``.
        resume: ``False`` to start fresh, or the checkpoint file stem (no ``.pth``)
            under ``root`` to resume from, e.g. ``'last_pretrainer'``.
        save: If true, rank 0 writes a checkpoint after each validation, to
            ``best_pretrainer.pth`` when accuracy ties/improves on the best so far,
            otherwise to ``last_pretrainer.pth``.
        num_epochs: Total epoch count, including epochs completed before a resume.
        batch_size: Batch size passed to both loaders.
        lr: Initial Adam learning rate (multiplied by 0.1 every 20 epochs by StepLR).
        seed: Optional seed for ``random``, ``numpy`` and ``torch``; ``None`` leaves
            the RNGs untouched.

    Raises:
        TypeError: if ``resume`` is truthy but not a checkpoint-name string.
    """
    if seed is not None:
        # Seed every RNG library this notebook imports so a run can be repeated.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    setup(rank, world_size)
    # Everything after setup() runs under try/finally so the process group is always
    # cleaned up; otherwise a failed run leaves it initialized and re-running this cell
    # in the same kernel presumably fails when setup() tries to initialize it again.
    try:
        tx_dict = tx()
        train_loader, train_sampler = get_dataset(world_size, rank, dataroot,
                                                  phase=phases[0], lim=100,
                                                  transform=tx_dict['train'],
                                                  batch_size=batch_size, num_workers=2)
        # NOTE(review): the val sampler is unused and validation runs only on rank 0,
        # so with world_size > 1 presumably only rank 0's shard is evaluated — confirm.
        val_loader, _val_sampler = get_dataset(world_size, rank, dataroot,
                                               phase=phases[1], lim=50,
                                               transform=tx_dict['val'],
                                               batch_size=batch_size, num_workers=2)

        # Create the model on this rank's GPU and wrap it for distributed training.
        encoder = Encoder(hidden_dim=256, num_encoder_layers=6, nheads=8)
        siamese_net = SiameseNetwork(encoder).to(rank)
        siamese_net = DDP(siamese_net, device_ids=[rank], find_unused_parameters=False)

        optimizer = torch.optim.Adam(siamese_net.parameters(), lr=lr)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)

        best_accuracy = 0
        start_epoch = 0

        if resume:
            if not isinstance(resume, str):
                raise TypeError('resume must be False or a checkpoint name string, got {!r}'
                                .format(resume))
            ckptfile = root + resume + '.pth'
            # Load on CPU; load_state_dict copies the tensors into the model/optimizer.
            ckpts = torch.load(ckptfile, map_location='cpu')
            siamese_net.load_state_dict(ckpts['model_state_dict'])
            optimizer.load_state_dict(ckpts['optimizer_state_dict'])
            # Without the scheduler state, StepLR restarts counting at 0 and decays at
            # the wrong epochs after a resume. Older checkpoints lack this key.
            if 'lr_scheduler_state_dict' in ckpts:
                lr_scheduler.load_state_dict(ckpts['lr_scheduler_state_dict'])
            start_epoch = ckpts['epoch'] + 1
            best_accuracy = ckpts['best_val_acc']

            if rank == 0:
                print('Resuming training from epoch {}. Loaded weights from {}. Last best accuracy was {}'
                      .format(start_epoch, ckptfile, best_accuracy))

        for epoch in range(start_epoch, num_epochs):
            # Presumably a DistributedSampler: set_epoch changes its shuffle per epoch.
            train_sampler.set_epoch(epoch)
            train_epoch(rank, siamese_net, optimizer, train_loader, epoch, num_epochs, running_loss=0)
            lr_scheduler.step()

            # Only rank 0 validates and writes checkpoints.
            if rank != 0:
                continue

            _vloss, acc = validate(rank, siamese_net, val_loader)
            if acc >= best_accuracy:
                best_accuracy = acc
                save_path = root + 'best_pretrainer.pth'
            else:
                # Holds the most recent non-improving epoch, not necessarily the final one.
                save_path = root + 'last_pretrainer.pth'

            if save:
                checkpoint = {
                    'epoch': epoch,
                    'model_state_dict': siamese_net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'lr_scheduler_state_dict': lr_scheduler.state_dict(),
                    'best_val_acc': best_accuracy,
                }
                torch.save(checkpoint, save_path)
                print('\nSaved weights to', save_path)
    finally:
        # Clean up the process group
        cleanup()


In [None]:
# Single-process run: rank 0 in a world of size 1, so pretrainer is called directly
# instead of being launched once per GPU.
root = './'                  # checkpoint path prefix, relative to the cwd (scr/dent after the %cd above)
dataroot = '../sample_pkl/'  # pickled positive/negative pair lists
world_size = 1

# NOTE(review): in the recorded output, validation "Correct" is 1470/1566, which equals
# the 1470 negative pairs reported for this split. The ~0.94 accuracy may therefore just
# reflect predicting "negative" for every pair (15:1 imbalance). Check a class-balanced
# metric before trusting it.
pretrainer(rank=0, world_size=world_size, root=root, dataroot=dataroot,
           phases=['sample', 'sample'])


Loading positive and negative pairs from pickled lists of sample
sample dataset has 96 positive pairs and 1470 Negative pairs.
Ratio of negative to positive samples = 15.3125

Loading positive and negative pairs from pickled lists of sample
sample dataset has 96 positive pairs and 1470 Negative pairs.
Ratio of negative to positive samples = 15.3125

                Device                 Epoch               GPU Mem                  Loss


                     0                   0/4                 8.56G              0.007569: 100%|██████████| 25/25 [00:23<00:00,  1.06it/s]          


                Device               Correct              Accuracy                  Loss



                     0             1470/1566                0.9387              0.007789: 100%|██████████| 25/25 [00:11<00:00,  2.17it/s]



                Device                 Epoch               GPU Mem                  Loss


                     0                   1/4                 9.83G              0.007817: 100%|██████████| 25/25 [00:21<00:00,  1.16it/s]          


                Device               Correct              Accuracy                  Loss



                     0             1470/1566                0.9387               0.00779: 100%|██████████| 25/25 [00:12<00:00,  1.99it/s]



                Device                 Epoch               GPU Mem                  Loss


                     0                   2/4                 9.83G              0.007997: 100%|██████████| 25/25 [00:21<00:00,  1.15it/s]          


                Device               Correct              Accuracy                  Loss



                     0             1470/1566                0.9387              0.007801: 100%|██████████| 25/25 [00:12<00:00,  2.05it/s]



                Device                 Epoch               GPU Mem                  Loss


                     0                   3/4                 9.83G              0.007464: 100%|██████████| 25/25 [00:22<00:00,  1.12it/s]          


                Device               Correct              Accuracy                  Loss



                     0             1470/1566                0.9387              0.007765: 100%|██████████| 25/25 [00:12<00:00,  2.00it/s]



                Device                 Epoch               GPU Mem                  Loss


                     0                   4/4                 9.83G               0.00778: 100%|██████████| 25/25 [00:22<00:00,  1.09it/s]          


                Device               Correct              Accuracy                  Loss



                     0             1469/1566                0.9381              0.007758: 100%|██████████| 25/25 [00:11<00:00,  2.16it/s]
