In [1]:
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
from torchvision import models
import glob
import random
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms as T
from tqdm import tqdm
from sklearn.metrics import precision_recall_fscore_support

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Ocelot(Dataset):
    def __init__(self,valid=False):
        self.valid = valid
        self.patches0 = list(glob.glob('/workspace/jay/DDP/Ocelot/cell_patches/patches/0_*.png'))+list(glob.glob('/workspace/jay/DDP/Ocelot/cell_patches/patches/2_*.png'))
        self.patches1 = glob.glob('/workspace/jay/DDP/Ocelot/cell_patches/patches/1_*.png')
        val = ['037', '038', '055', '056', '078', '087', '095', '108', '114','141', '143', '146', '148', '153', '180', '184', '187', '188',
       '189', '194', '195', '202', '213', '223', '238', '242', '276','285', '287', '291', '300', '319', '333', '349', '358', '385',
       '386', '388', '390', '399','013', '016', '024', '026', '027', '028', '039', '047', '049','369', '374', '393', '398',
       '061', '066', '067', '101', '107', '116', '121', '140', '155','159', '163', '164', '166', '167', '172', '177', '190', '196',
       '203', '211', '237', '312', '315', '318', '337', '350', '362']             
        self.patches_0 = [x for x in self.patches0 if x.split('/')[-1].split('_')[1] not in val]
        self.patches_1 = [x for x in self.patches1 if x.split('/')[-1].split('_')[1] not in val]
#         print(len(self.patches0)+len(self.patches1))
#         print(len(self.patches_0)+len(self.patches_1))
        
        self.valid_path = [x for x in self.patches0 if x.split('/')[-1].split('_')[1] in val]+[x for x in self.patches1 if x.split('/')[-1].split('_')[1] in val]
        self.valid_paths = []
        for i in range(len(self.valid_path)):
            if self.valid_path[i].split('/')[-1].split('_')[0]==2:
                pass
            else:
                self.valid_paths.append(self.valid_path[i])
        del self.valid_path
        
    def __len__(self): 
        if self.valid==False:
            return 100000
        else:
            return len(self.valid_paths)

    def __getitem__(self, idx):
        if self.valid==False:
            idx = random.randint(0,1)
            if idx==0:
                file = random.choice(self.patches_0)
                label = torch.Tensor([0.0])
            else:
                file = random.choice(self.patches_1)
                label = torch.Tensor([1.0])
        else:
            file = self.valid_paths[idx]
            label = torch.Tensor([float(file.split('/')[-1].split('_')[0])])
            
        name = file.split('/')[-1].split('.')[0].split('_')[1]
        cell = Image.open(file)
        
        if self.valid==False:
            if random.uniform(0, 1)>0.5:
                cell = T.functional.hflip(cell)
            rot = random.uniform(0,1)
            if rot<0.5 and rot>=0.25:
                cell = T.functional.rotate(cell,angle=90)
            elif rot>=0.5 and rot<0.75:
                cell = T.functional.rotate(cell,angle=180)
            elif rot>=0.75 and rot<=1:
                cell = T.functional.rotate(cell,angle=270)
        
        cell = np.array(cell)
        cell = cell / 255
        cell = cell - 0.5
        cell = torch.Tensor(np.moveaxis(cell, -1, 0))
    
        return cell,label,name

In [3]:
class clsmodel(torch.nn.Module):
    def __init__(self):
        super(clsmodel,self).__init__()
        resnet = models.resnet18(pretrained=True)
        self.linear1 = torch.nn.Linear(resnet.fc.in_features,1024)
        self.linear2 = torch.nn.Linear(1024,1)
        self.sigmoid = torch.nn.Sigmoid()
        self.relu = torch.nn.ReLU()
        #self.dropout = torch.nn.Dropout(0.25)
        self.model = torch.nn.Sequential(*list(resnet.children())[:-1])
        #deactivate_requires_grad(self.backbone)
    def forward(self,x):
        features = self.model(x).flatten(start_dim=1)
        x = self.relu(self.linear1(features))
        x = self.linear2(x)
        pred = self.sigmoid(x)
        return pred

In [4]:
device = 'cuda:4'
batch_size = 256
model = clsmodel().to(device)
ds = Ocelot()
val_ds = Ocelot(valid=True)
dl = DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2)



In [5]:
ce_loss = torch.nn.BCELoss().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999), eps=1e-08,weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
num_epochs = 50

best_loss = float('inf')

In [6]:
for epoch in (range(num_epochs)):
    total_loss = 0
    model = model.train()
    for cell,label,_ in (dl):
        optimizer.zero_grad()
        image = cell.to(device)
        label = label.to(device)
        out = model(image)
        loss = ce_loss(out, label)
        total_loss += loss
        loss.backward()
        optimizer.step()
    
    model = model.eval()
    with torch.no_grad():
        total_val_loss = 0
        y_truef = np.array([])
        y_predf = np.array([])
        for cell,label,_ in (val_dl):
            image = cell.to(device)
            label = label.to(device)
            out = model(image)
            val_loss = ce_loss(out, label)         
            total_val_loss += loss
            y_true = np.squeeze(label.cpu().numpy())
            y_pred = np.squeeze(out.cpu().numpy())
            y_pred[np.where(y_pred>0.5)]=1
            y_pred[np.where(y_pred<=0.5)]=0
            y_truef = np.concatenate((y_truef,y_true),axis=None)
            y_predf = np.concatenate((y_predf,y_pred),axis=None)
        score = precision_recall_fscore_support(y_truef, y_predf,average='weighted')
        print(f'{epoch} -- Train loss: {total_loss/len(dl):.3f}')
        print(f'{epoch} -- Val loss: {total_val_loss/len(val_dl):.3f} Score(P,R,F1): {score}')
    
    scheduler.step()
    torch.save(model.state_dict(), f'/workspace/jay/DDP/Ocelot/classifier/ckpts_v1/{epoch}_{score[2]:.4f}.pt')
            

0 -- Train loss: 0.274
0 -- Val loss: 0.134 Score(P,R,F1): (0.8163800218926457, 0.808121296619031, 0.8111003248059154, None)
1 -- Train loss: 0.160
1 -- Val loss: 0.124 Score(P,R,F1): (0.8036028908861219, 0.7906936214708958, 0.7949966564445854, None)
2 -- Train loss: 0.119
2 -- Val loss: 0.109 Score(P,R,F1): (0.7999522027392307, 0.7859010108051586, 0.7905227421311295, None)
3 -- Train loss: 0.098
3 -- Val loss: 0.073 Score(P,R,F1): (0.81767908460072, 0.8014987800627397, 0.8063025426863146, None)
4 -- Train loss: 0.083
4 -- Val loss: 0.054 Score(P,R,F1): (0.789382925642532, 0.7846810735447891, 0.7866764156281165, None)
5 -- Train loss: 0.071
5 -- Val loss: 0.098 Score(P,R,F1): (0.7949721413184668, 0.7843325200418264, 0.788155500277924, None)
6 -- Train loss: 0.067
6 -- Val loss: 0.035 Score(P,R,F1): (0.815133442281015, 0.8050714534681074, 0.8085385864396897, None)
7 -- Train loss: 0.060
7 -- Val loss: 0.065 Score(P,R,F1): (0.8224605421733914, 0.8178807947019867, 0.8197178788122009, None