In [1]:
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import torch, torchvision 
from PIL import Image

In [None]:
# Per-sample metadata table; row order must match data_imgs2 (loaded below).
df_infos4 = pd.read_csv(filepath_or_buffer='../saved/infos4.csv')

In [None]:
import pickle
# NOTE(security): pickle.load can execute arbitrary code; only load this file
# if it was produced by a trusted pipeline.
with open('../saved/data_imgs2.pkl', 'rb') as pkl_file:
    data_imgs2 = pickle.load(pkl_file)

In [None]:
# Visual sanity check: channel image 3 of sample 5.
sample_channel_img = data_imgs2[5][3]
plt.imshow(sample_channel_img)

In [None]:
n_segments = 11
def crop_imgs(data_img, n_segments = n_segments, vis = False, crop_size=224, n_channels=4):
    """Slide a square window left-to-right across a multichannel image.

    Args:
        data_img: sequence of per-channel images; ``data_img[chn]`` must expose
            ``.size`` as ``(width, height)`` and ``.crop(box)`` (as PIL images do).
            The stride is computed from channel 0's width.
        n_segments: number of evenly spaced windows. The stride is
            ``int((width - crop_size) / (n_segments - 1))``; with a single
            segment the window is taken at the left edge.
        vis: if True, plot every channel of the full image and every crop.
        crop_size: side length of each square crop; rows ``0..crop_size`` are kept.
        n_channels: number of leading entries of ``data_img`` to crop.

    Returns:
        list of ``n_segments`` lists, each holding ``n_channels`` crops
        (segment-major, channel-minor).
    """
    # With one segment the stride formula would divide by zero.
    if n_segments > 1:
        shift_len = int((data_img[0].size[0] - crop_size) / (n_segments - 1))
    else:
        shift_len = 0

    if vis:
        # data_img is a sequence of images and imshow expects a single image,
        # so show each channel separately.
        for chn in range(n_channels):
            plt.imshow(data_img[chn])
            plt.show()

    imgs = []
    for i in range(n_segments):
        left = i * shift_len
        box = (left, 0, left + crop_size, crop_size)  # (left, upper, right, lower)
        img_chns = []
        for chn in range(n_channels):
            img = data_img[chn].crop(box)
            img_chns.append(img)
            if vis:
                plt.imshow(img)
                plt.show()
        imgs.append(img_chns)
    return imgs

# Flatten the per-sample crop lists: sample i contributes n_segments consecutive
# entries, each a list of per-channel crops.
data_img2_crops = []
for sample_idx in tqdm(range(len(data_imgs2))):
    data_img2_crops.extend(crop_imgs(data_imgs2[sample_idx]))
    

In [None]:
# One metadata row per crop: each sample row is repeated n_segments times so that
# row k lines up with data_img2_crops[k] (crops of a sample are stored consecutively).
# Repeat by position rather than by label so a non-unique index cannot multiply rows.
df_infos4_crops = df_infos4.iloc[np.repeat(np.arange(len(df_infos4)), n_segments)].reset_index(drop=True)
assert len(df_infos4_crops) == len(data_img2_crops), 'metadata rows and image crops are misaligned'

In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
from global_vars import labels
import os
class MyImageMultichannelDataset(Dataset):
    """Dataset pairing stacked multichannel image crops with multi-label targets.

    Sample ``idx`` is built from ``channel_imgs[idx]`` (one image per channel)
    and row ``idx`` (by position) of ``infos``.
    """

    def __init__(self, infos, n_segments, channel_imgs, n_channels=4, label_cols=None):
        """
        Args:
            infos: DataFrame with one row per sample, positionally aligned with
                ``channel_imgs``; must contain the label columns.
            n_segments: number of crops per original sample. Stored for
                reference only; indexing here is per crop.
            channel_imgs: sequence where ``channel_imgs[idx][chn]`` is an image
                accepted by ``transforms.ToTensor`` (e.g. a PIL image).
            n_channels: number of per-sample images concatenated along dim 0.
            label_cols: label column names in ``infos``; defaults to
                ``global_vars.labels``.

        Raises:
            ValueError: if ``infos`` and ``channel_imgs`` differ in length.
        """
        if len(infos) != len(channel_imgs):
            raise ValueError('infos has {} rows but channel_imgs has {} samples'.format(
                len(infos), len(channel_imgs)))
        self.infos = infos
        self.channel_imgs = channel_imgs
        self.n_segments = n_segments
        self.n_channels = n_channels
        self.label_cols = list(labels) if label_cols is None else list(label_cols)
        # ImageNet statistics (3-entry mean/std), so every channel image must
        # convert to a 3-plane tensor (e.g. RGB).
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.infos)

    def __getitem__(self, idx):
        """Return ``(images, targets)`` for sample ``idx``.

        images: the ``n_channels`` transformed images concatenated on dim 0.
        targets: float32 tensor of the label columns (values cast through int).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        channel_tensors = [self.transform(self.channel_imgs[idx][chn]) for chn in range(self.n_channels)]
        label_values = self.infos.iloc[idx][self.label_cols].astype(int).to_numpy()
        return torch.cat(channel_tensors, 0), torch.as_tensor(label_values, dtype=torch.float32)
    
# One dataset item per crop; metadata and crops are aligned row-for-row.
image_datasets = MyImageMultichannelDataset(
    infos=df_infos4_crops,
    n_segments=n_segments,
    channel_imgs=data_img2_crops,
)

In [None]:
# Smoke test: fetch one sample to check the dataset pipeline runs end to end.
first_sample = image_datasets[0]
imgs0, label0 = first_sample


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
from torch import nn
from torchvision import models
# import torch.nn.functional as F
# import torch.optim as optim

# def load_cwt_nn_model(device, model_saved_path='../saved/modelCWTFullWeightedLoss/modelCWTFull0_model.dict', 
#                       freeze=True):
#     model = models.resnet50(pretrained=True)
#     num_ftrs = model.fc.in_features
#     model.fc = nn.Linear(num_ftrs, 9)
#     model.load_state_dict(torch.load(model_saved_path, map_location=device))
    
#     if freeze:
#         for param in model.parameters():
#             param.requires_grad = False
#     # change the last output to 9 classes
#     model.fc = nn.Linear(num_ftrs, 1000)
    
#     # load saved model
    
#     model.eval()
#     model.to(device)
#     return model

# class MultiCWTNet(nn.Module):
#     def __init__(self, device, verbose=False):
#         super(MultiCWTNet, self).__init__()
        
#         self.chn_resnets = [load_cwt_nn_model(device) for i in range(4)]
#         num_ftrs = self.chn_resnets[0].fc.out_features
#         self.fc1 = nn.Linear(num_ftrs*4, 9)
#         self.fc2 = nn.Linear(9, 9)
#         self.verbose = verbose
        
#     def forward(self, xs):
#         x = [self.chn_resnets[i](xs[i]) for i in range(4)]
        
#         x = torch.cat(x, 1)
#         if self.verbose:
#             print('0: ', x.shape, x.device)
            
#         x = self.fc1(x)
#         if self.verbose:
#             print('2: ', x.shape)
            
#         x = F.relu(   x )     
#         if self.verbose:
#             print('3: ', x.shape)
            
#         x = self.fc2(x)
        
#         if self.verbose:
#             print('4: ', x.shape)
#         return x

# model = MultiCWTNet(device, verbose=True)
# model.to(device)
# for X,y in dataloaders:
#     xs = [x.to(device) for x in X]
#     model.forward(xs)
#     break

In [None]:
from torch import nn
from torchvision import models
    
class MultiCWTNet(nn.Module):
    """ImageNet-pretrained ResNet-50 whose first conv takes 12 input planes and
    whose final layer outputs 9 logits.

    Args:
        device: not used by the module; kept for call-site compatibility.
        verbose: stored on the instance; not read by ``forward``.
    """
    def __init__(self, device, verbose=False):
        super(MultiCWTNet, self).__init__()

        self.resnet = models.resnet50(pretrained=True)
        # Widen conv1 to 12 input planes: the pretrained weights fill the first
        # planes and channel 0's filters are replicated into the added ones.
        self.resnet.conv1 = self.increase_channels(self.resnet.conv1, num_channels=12, copy_weights=0)
        num_ftrs = self.resnet.fc.in_features
        self.resnet.fc = nn.Linear(num_ftrs, 9)

        self.verbose = verbose

    def forward(self, xs):
        """Return the 9 raw logits of the ResNet for input batch ``xs``."""
        x = self.resnet(xs)
        return x

    def increase_channels(self, m, num_channels=None, copy_weights=0):
        """Return a copy of Conv2d ``m`` that accepts ``num_channels`` input planes.

        Adapted from
        https://github.com/akashpalrecha/Resnet-multichannel/blob/master/multichannel_resnet.py

        Args:
            m: source ``nn.Conv2d``.
            num_channels: input planes of the new layer; defaults to
                ``m.in_channels + 1``.
            copy_weights (int): index of the input channel of ``m`` whose filters
                are copied into every added input channel.

        Returns:
            A new ``nn.Conv2d`` with ``m``'s out_channels, kernel_size, stride,
            padding and dilation. Its first ``m.in_channels`` input planes hold
            ``m``'s weights, and ``m``'s bias is copied when present.

        Raises:
            ValueError: if ``num_channels`` is smaller than ``m.in_channels``.
        """
        new_in_channels = num_channels if num_channels is not None else m.in_channels + 1
        if new_in_channels < m.in_channels:
            raise ValueError('num_channels ({}) must be >= m.in_channels ({})'.format(
                new_in_channels, m.in_channels))

        new_m = nn.Conv2d(in_channels=new_in_channels,
                          out_channels=m.out_channels,
                          kernel_size=m.kernel_size,
                          stride=m.stride,
                          padding=m.padding,
                          dilation=m.dilation,
                          bias=m.bias is not None)

        # new_m.weight is a leaf Parameter that requires grad; writing into its
        # slices must be done under no_grad or autograd raises a RuntimeError.
        with torch.no_grad():
            new_m.weight[:, :m.in_channels, :, :] = m.weight
            n_extra = new_in_channels - m.in_channels
            if n_extra > 0:
                source = m.weight[:, copy_weights:copy_weights + 1, :, :]
                new_m.weight[:, m.in_channels:, :, :] = source.expand(-1, n_extra, -1, -1)
            if m.bias is not None:
                new_m.bias.copy_(m.bias)

        return new_m

In [None]:
def geometry_loss(fbeta, gbeta):
    """Geometric mean of the F-beta and G-beta scores, ``sqrt(fbeta * gbeta)``.

    Despite the name this is a score rather than a loss: larger inputs give a
    larger result.
    """
    product = fbeta * gbeta
    return np.sqrt(product)

In [None]:
from torchvision import datasets, models, transforms
from myeval import agg_y_preds_bags, binary_acc
import torch.optim as optim
from torch.optim import lr_scheduler
from snippets.pytorchtools import EarlyStopping
from sklearn.model_selection import GroupKFold
import time

st = time.time()
patience = 20
kf = GroupKFold(5)
batch_size = 70
n_epochs = 25
start_fold = 2  # folds with index below this are skipped

saved_dir = '../saved/modelMultiCWTFull/'
y = df_infos4_crops[labels].astype(int)

for i, (train_idx, test_idx) in enumerate(kf.split(df_infos4_crops, y, df_infos4_crops['ptID'])):

    if i < start_fold:
        continue

    trainDataset = torch.utils.data.Subset(image_datasets, train_idx)
    testDataset = torch.utils.data.Subset(image_datasets, test_idx)

    # agg_y_preds_bags groups *consecutive* rows into bags of n_segments crops, so
    # predictions must be aggregated in dataset order (assumes train_idx/test_idx are
    # sorted, as GroupKFold yields them — confirm for the sklearn version in use).
    # The test loader therefore does not shuffle; the train loader is rebuilt each
    # epoch from an explicit permutation that is undone before aggregation.
    testLoader = torch.utils.data.DataLoader(testDataset, batch_size=batch_size, shuffle=False, pin_memory=True)

    model = MultiCWTNet(device, verbose=False)
    model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    # Decay LR by a factor of 0.1 every 10 epochs
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    pos_weight = np.ones(9) * 2
    pos_weight = torch.Tensor(pos_weight).to(device)

    df_y_train = df_infos4_crops.iloc[train_idx][labels].to_numpy().astype(int)
    # Inverse-frequency class weights normalised to sum to 1. A label with no
    # positives in this fold would otherwise get an infinite weight (NaN loss).
    class_weights = 1.0 / np.maximum(np.sum(df_y_train, axis=0), 1)
    class_weights = class_weights / np.sum(class_weights)
    class_weights = torch.Tensor(class_weights).to(device)
    criterion_train = nn.BCEWithLogitsLoss(weight=class_weights, pos_weight=pos_weight, reduction='sum')

    criterion_test = nn.BCEWithLogitsLoss(pos_weight=pos_weight, reduction='mean')

    # One entry per epoch: mean of that epoch's batch losses.
    avg_losses_train = []
    avg_losses_test = []

    early_stopping = EarlyStopping(patience, verbose=False,
                                   saved_dir=saved_dir,
                                   save_name='MutliCWTNetFull11'+str(i))
    epoch = 0
    for epoch in range(n_epochs):

        model.train()
        # Explicit shuffle so the outputs can be restored to dataset order below.
        perm = torch.randperm(len(trainDataset))
        trainLoader = torch.utils.data.DataLoader(trainDataset, batch_size=batch_size,
                                                  sampler=perm.tolist(), pin_memory=True)
        losses_train = []
        output_trains = []
        y_trains = []
        for X_train, y_train in tqdm(trainLoader):
            y_train = y_train.to(device)
            X_train = X_train.to(device)
            optimizer.zero_grad()
            output_train = model(X_train)
            loss_train = criterion_train(output_train, y_train)
            losses_train.append(loss_train.item())
            loss_train.backward()
            optimizer.step()

            # Keep only the values; detach so the stored logits carry no autograd graph.
            output_trains.append(output_train.detach().cpu())
            y_trains.append(y_train.cpu())

        scheduler.step()

        avg_loss_train = np.average(losses_train)
        avg_losses_train.append(avg_loss_train)

        losses_test = []
        output_tests = []
        y_tests = []
        model.eval()
        with torch.no_grad():
            for X_test, y_test in testLoader:
                y_test = y_test.to(device)
                X_test = X_test.to(device)
                output_test = model(X_test)

                loss_test = criterion_test(output_test, y_test)
                losses_test.append(loss_test.item())

                output_tests.append(output_test.cpu())
                y_tests.append(y_test.cpu())

        avg_loss_test = np.average(losses_test)
        avg_losses_test.append(avg_loss_test)

        # Row j of the concatenated train batches is trainDataset[perm[j]];
        # indexing with argsort(perm) puts the rows back in dataset order.
        restore_order = torch.argsort(perm)
        y_trains = torch.cat(y_trains, dim=0)[restore_order]
        output_trains = torch.cat(output_trains, dim=0)[restore_order]
        y_train_preds = torch.sigmoid(output_trains)

        y_tests = torch.cat(y_tests, dim=0)
        output_tests = torch.cat(output_tests, dim=0)
        y_test_preds = torch.sigmoid(output_tests)

        y_train_preds_max, y_train_preds_mean, _ = agg_y_preds_bags(y_train_preds, bag_size=n_segments)
        y_test_preds_max, y_test_preds_mean, _ = agg_y_preds_bags(y_test_preds, bag_size=n_segments)
        _, _, y_trains = agg_y_preds_bags(y_trains, bag_size=n_segments)
        _, _, y_tests = agg_y_preds_bags(y_tests, bag_size=n_segments)

        # k == 0: MAX-aggregated bag predictions, k == 1: MEAN-aggregated.
        for k, (y_train_preds, y_test_preds) in enumerate(zip([y_train_preds_max, y_train_preds_mean],
                                                              [y_test_preds_max, y_test_preds_mean])):

            acc, fmeasure, fbeta, gbeta, auroc, auprc = binary_acc(y_train_preds, y_trains)

            acc2, fmeasure2, fbeta2, gbeta2, auroc2, auprc2 = binary_acc(y_test_preds, y_tests)

            geometry = geometry_loss(fbeta, gbeta)
            geometry2 = geometry_loss(fbeta2, gbeta2)
            output_str = 'S{}/{} {:.2f} min {}|\n Train Loss: {:.6f}, Acc: {:.3f}, F: {:.3f}, Fbeta: {:.3f}, gbeta: {:.3f}, auroc: {:.3f}, auprc: {:.3f}, geo: {:.3f} |\nValid Loss: {:.6f}, Acc: {:.3f}, F: {:.3f}, Fbeta: {:.3f}, gbeta: {:.3f}, auroc: {:.3f}, auprc: {:.3f}, geo: {:.3f}\n '.format(
                i, epoch, (time.time()-st)/60, 'MEAN' if k == 1 else 'MAX',
                avg_loss_train, acc, fmeasure, fbeta, gbeta, auroc, auprc, geometry,
                avg_loss_test, acc2, fmeasure2, fbeta2, gbeta2, auroc2, auprc2, geometry2)
            print(output_str)

            with open(saved_dir+'loss11_{}.txt'.format(i), 'a') as f:
                print(output_str, file=f)

        # After the loop, geometry2 holds the MEAN-aggregated validation score.
        early_stopping(-geometry2, model)

        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Summary uses the metrics of the final epoch that ran (MEAN aggregation).
    output_string = 'AUROC|AUPRC|Accuracy|F-measure|Fbeta-measure|Gbeta-measure|Geomotry\n{:.3f}|{:.3f}|{:.3f}|{:.3f}|{:.3f}|{:.3f}|{:.3f}'.format(auroc2,auprc2,acc2,fmeasure2,fbeta2,gbeta2,geometry2)
    print(output_string)
    with open(saved_dir+'score'+ str(i)+ '_epoch' + str(epoch) + '.txt', 'w') as f:
        f.write(output_string)

    avg_losses_train = np.array(avg_losses_train)
    avg_losses_test = np.array(avg_losses_test)

    np.save(saved_dir + 'avg_losses_train' + str(i) + '_epoch' + str(epoch), avg_losses_train)
    np.save(saved_dir + 'avg_losses_test' + str(i) + '_epoch' + str(epoch), avg_losses_test)
    