In [2]:
from ResNet50Classifier.py import Resnet_classifier
from SupCon import Embedding_Network

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.nn as nn
import torch.optim as optim
import os
import numpy as np
import pandas as pd
from tqdm import tqdm

In [4]:
class WhaleDataset_eval(data.Dataset):
    def __init__(self, csv):
        
        self.df = pd.read_csv(csv)

        self.groups = self.df.groupby('individual_id').groups
        self.keys = list(self.groups.keys())
        self.label_encoder = {}
        for i in range(len(self.keys)):
            self.label_encoder[self.keys[i]] = i
        
       
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):         
        img = torchvision.io.read_image("train_images_cropped/" + self.df["image"].iloc[index]).float()

        
        if img.shape[0] == 1:
            img = torch.cat((img,img,img))

        label = self.label_encoder[self.df["individual_id"].iloc[index]]
    
        return {
            'image': img,
            'id' : label
        }

In [5]:
#Define a simple MLP for predictions
class MLP(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(in_size, in_size),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(in_size, in_size),
            nn.ReLU(inplace=True),
            nn.Linear(in_size, out_size)
        )     
    def forward(self,x):
        return self.classifier(x)
    

In [8]:
def train_mlp(model, embedding_model,train_loader, optimizer, criterion,device):
      model.train()
      for batch_idx, _data in tqdm(enumerate(train_loader)):

                image = _data["image"]
                
                target = _data["id"]

                image, target = image.to(device), target.to(device)
                with torch.no_grad():
                    vect = embedding_model(image)
                    
                optimizer.zero_grad()
                
                pred = model(vect)
                loss = criterion(pred,target)
                loss.backward()
                optimizer.step()

In [9]:
def evaluate(model, embedding_model, test_loader, criterion,device):
    avg_loss = 0
    l = len(test_loader)
    acc = 0
    model.eval()
    with torch.no_grad():
        for batch_idx, _data in tqdm(enumerate(test_loader)):
            image = _data["image"]
                
            target = _data["id"]
            
            image, target = image.to(device), target.to(device)
            vect = embedding_model(image)
            pred = model(vect)
                   
            loss = criterion(pred,target)
            avg_loss += loss/l

            #compute accuracy
            _, predicted_ids = torch.max(pred,1)
            #print((predicted_ids == target).sum().item())
            acc += (predicted_ids == target).sum().item()/l/1024
    print(acc)
    return acc

In [10]:
def main(model, dataset, learning_rate = 5e-3, batch_size = 1024, epochs = 10):
    hparams = {
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs
    }
    
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    conv_model = torchvision.models.resnet18(pretrained = False)
    embedding_model = Embedding_Network(conv_model)
    embedding_model.load_state_dict(torch.load("res_net_supcon_1.pt", map_location = torch.device('cuda'))) 
    embedding_model.to(device)
    embedding_model.eval()
    
    model.to(device)
    _, test_dataset = torch.utils.data.random_split(dataset, [len(dataset) -5000,5000])
    #test_dataset = WhaleDataset(csv, image_folder, classes_map)
    kwargs = {'num_workers': 6, 'pin_memory': True} if use_cuda else {}
    train_loader = data.DataLoader(dataset=dataset,
                                batch_size=hparams['batch_size'],
                                shuffle=True,
                                
                                **kwargs)
    test_loader= data.DataLoader(dataset=test_dataset,
                                batch_size=hparams['batch_size'],
                                shuffle=True,
                                
                                **kwargs)
    """
    test_loader = data.DataLoader(dataset=test_dataset,
                                batch_size=hparams['batch_size'],
                                shuffle=False,
                                collate_fn=lambda x: data_processing(x, 'valid'),
                                **kwargs)
    """
    
    
    print('Num Model Parameters', sum([param.nelement() for param in model.parameters()]))
    
    
    optimizer = optim.Adam(model.parameters(), hparams['learning_rate'])
    loss = nn.CrossEntropyLoss()
    print("it # per epoch : " + str(len(train_loader)))
    for i in range(1, epochs+1):
        train_mlp(model, embedding_model,train_loader,optimizer,loss,device)
        evaluate(model, embedding_model, test_loader,loss,device)
    torch.save(model.state_dict(), "MLP.pt")

In [None]:
class Resnet_Classifier(nn.Module):
    def __init__(self, in_size, out_size):
        super().__init__()
        self.conv_model = torchvision.models.resnet50(pretrained = True)
        self.classifier =  nn.Sequential(
            nn.Dropout(),
            nn.Linear(in_size, in_size),
            nn.Tanh(),
            nn.Dropout(),
            nn.Linear(in_size, in_size),
            nn.Tanh(),
            nn.Linear(in_size, out_size)
        )     
        self.linear_classifier = nn.Sequential(nn.Linear(in_size,out_size), nn.Softmax(dim=1))
        self.linear = nn.Linear(in_size,out_size)
        self.softmax = nn.Softmax(dim=1)
    def forward(self, batch_data):
        x = self.conv_model(batch_data)
        x = self.classifier(x)
        return x

In [11]:
dataset = WhaleDataset_eval("train_corrected.csv")
print(len(dataset.keys))
model = MLP(128,len(dataset.keys))
#model = Resnet_Classifier(1000,len(dataset.keys))
main(model, dataset)

15587
Num Model Parameters 2043747
it # per epoch : 50


50it [00:43,  1.15it/s]
5it [00:07,  1.46s/it]

0.0042968749999999995



50it [00:39,  1.25it/s]
5it [00:06,  1.34s/it]

0.004296875



50it [00:40,  1.24it/s]
5it [00:07,  1.46s/it]

0.004296875



50it [00:39,  1.27it/s]
5it [00:06,  1.28s/it]

0.0041015625



50it [00:39,  1.25it/s]
5it [00:07,  1.40s/it]

0.0039062499999999996



50it [00:40,  1.24it/s]
5it [00:06,  1.34s/it]

0.00546875



50it [00:40,  1.23it/s]
5it [00:07,  1.41s/it]

0.005078124999999999



50it [00:40,  1.23it/s]
5it [00:06,  1.33s/it]

0.0048828125



50it [00:40,  1.22it/s]
5it [00:06,  1.34s/it]

0.0062499999999999995



50it [00:39,  1.27it/s]
5it [00:06,  1.32s/it]

0.0062499999999999995



