In [1]:
import Model as md
import Resnet as res
import data as dta
import lars_optimizer as lars

import torch
import torch.nn as nn
import torch.optim as optim
from pytorch_metric_learning import losses
from torchvision import transforms as T
from torch.utils.data import DataLoader
from absl import app
from absl import flags
from absl import logging
import os
from tqdm import tqdm

from torch.cuda import amp
import pandas as pd
import math

In [2]:
#device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = res.ResNet50()
encoder = encoder.cuda()
proj_head = md.Projection_Head(encoder.representation_dim).cuda()

In [3]:
optimizer = optim.SGD(list(encoder.parameters()) + list(proj_head.parameters()), lr=math.sqrt(128)*0.075, weight_decay=1e-6,momentum=0.9)
#optimizer = lars.LARS(base_optimizer, trust_coefficient=0.001)
#optimizer = optim.AdamW(list(encoder.parameters()) + list(proj_head.parameters()), lr=0.25, weight_decay=1e-4)
ntxent_loss = losses.NTXentLoss(temperature=0.1)

In [4]:
#Transformations
transf = T.Compose([
    T.CenterCrop(400),
    T.Resize(112),
    T.RandomVerticalFlip(),
    T.RandomHorizontalFlip(),
    T.RandomApply(
    [T.ColorJitter(brightness=(0.65,1.5), contrast=(0.65,1.3), saturation=(0.65,1.3), hue=0.2),
    T.RandomResizedCrop(112, scale=(0.2, 1.0))], p=0.8),
    #T.Resize(330),
    T.RandomGrayscale(p=0.2),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomApply(
    [T.GaussianBlur(kernel_size=3, sigma=(0.1, 5.1))], p=0.5),
    T.ToTensor(),
])

In [5]:
astro_ds = dta.AstroDataset('nair_unbalanced_train.csv', 'imagenes_clasificadas_nair/', transform=transf)
dataset_astro = DataLoader(astro_ds,batch_size=128, shuffle=True,num_workers=6)

In [6]:
astro_ds.__len__()*100//(256+1)

3642

In [7]:
astro_ds.__len__() * 100 // 256 + 1

3657

In [8]:
(9361 * 5 // 256 + 1)

183

In [9]:
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, (astro_ds.__len__() * 5 // 128 + 1))

In [10]:
scaler = amp.GradScaler()

In [11]:
def save_model(encoder, projection_head, epoch_number, optimizer, scheduler):
    torch.save({
        'encoder': encoder.state_dict(),
        'projection_head': projection_head.state_dict(),
        'epoch': epoch,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict()
    }, 'model/model.pt')

In [12]:
df = pd.DataFrame(columns=['Epoch','ContrastiveLoss','ContrastiveAccuracy'])

In [13]:
for epoch in range(0, 5):
    acc_epoc = 0
    epoch_loss = 0
    #use tqdm
    tqdm_loop = tqdm(enumerate(dataset_astro), total=len(dataset_astro), leave=True)
    encoder.train()
    proj_head.train()
    for batch_idx, data in tqdm_loop:
        data = data.cuda()
        transformed_img1, transformed_img2 = torch.split(data, 3, dim=1)
        transformed_img1, transformed_img2 = transformed_img1.cuda(), transformed_img2.cuda()
        inputs = torch.cat((transformed_img1,transformed_img2),0)
        optimizer.zero_grad()
        with amp.autocast():
            projection = proj_head(encoder(inputs))
            pseudolabels = torch.arange(transformed_img1.size(0)).cuda()
            pseudolabels = torch.cat([pseudolabels, pseudolabels], dim=0)
            loss = ntxent_loss(projection, pseudolabels)
            hiddens = torch.split(projection,[projection.size(0)//2,projection.size(0)//2],dim=0)
            #print(hiddens[0].shape)
            logits = torch.matmul(hiddens[0], torch.transpose(hiddens[1], 0,1))/0.1
            #print(logits.shape)
            contrastive_acc = torch.argmax(logits,dim=1)
            contrastive_acc = torch.mean(torch.eq(pseudolabels[:projection.size(0)//2], contrastive_acc).float())
            #print(contrastive_acc)
            acc_epoc += contrastive_acc.item()
            epoch_loss += loss.item()
        scaler.scale(loss).backward()
        #loss.backward()
        scaler.step(optimizer)
        scaler.update()
        #optimizer.step()
        #epoch_loss += 
        #print(loss.item())
        scheduler.step()
        #update progress bar
        tqdm_loop.set_description(f'Epoch [{epoch}/{5}]')
        tqdm_loop.set_postfix(loss = loss.item())
    #save_model(encoder, proj_head, epoch, optimizer)
    acc_epoc = acc_epoc/(batch_idx+1)
    epoch_loss /= (batch_idx+1)
    d_list = [epoch, epoch_loss, acc_epoc]
    df.loc[len(df), :] = d_list
    print('Epoch: {}, Loss: {}, Contrastive Accuracy: {}'.format(epoch, epoch_loss, acc_epoc*100))

Epoch [0/5]: 100%|██████████████████████████████████████████████████████████| 74/74 [00:24<00:00,  3.05it/s, loss=4.17]


Epoch: 0, Loss: nan, Contrastive Accuracy: 1.008545310312026


Epoch [1/5]: 100%|██████████████████████████████████████████████████████████| 74/74 [00:23<00:00,  3.13it/s, loss=4.17]


Epoch: 1, Loss: nan, Contrastive Accuracy: 0.7706925675675675


Epoch [2/5]: 100%|██████████████████████████████████████████████████████████| 74/74 [00:23<00:00,  3.11it/s, loss=4.16]


Epoch: 2, Loss: nan, Contrastive Accuracy: 0.8501838238255397


Epoch [3/5]: 100%|███████████████████████████████████████████████████████████| 74/74 [00:24<00:00,  3.05it/s, loss=nan]


Epoch: 3, Loss: nan, Contrastive Accuracy: 0.8501838238255397


Epoch [4/5]:  24%|██████████████▎                                            | 18/74 [00:07<00:22,  2.46it/s, loss=nan]


KeyboardInterrupt: 

In [None]:
df

In [None]:
import math
math.sqrt(128)