In [1]:
import gc
import os
import shutil
import tarfile
from collections import defaultdict
from glob import glob

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader,TensorDataset
from torchvision import io
from torchvision import datasets, transforms, models
import torchvision.models as models
from torchsummary import summary
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.metrics import auc, roc_auc_score, roc_curve

from google import colab
colab.drive.mount('/content/gdrive')

try : 
  import river
except :
  !pip install git+https://github.com/online-ml/river --upgrade
  import river

Mounted at /content/gdrive
Collecting git+https://github.com/online-ml/river
  Cloning https://github.com/online-ml/river to /tmp/pip-req-build-7rwa0dca
  Running command git clone -q https://github.com/online-ml/river /tmp/pip-req-build-7rwa0dca
Building wheels for collected packages: river
  Building wheel for river (setup.py) ... [?25l[?25hdone
  Created wheel for river: filename=river-0.1.0-cp37-cp37m-linux_x86_64.whl size=1642527 sha256=8dc1119474ed6daa0e19ab0e6153777a61ae1c6e992c71e275e6f5ddf4703af4
  Stored in directory: /tmp/pip-ephem-wheel-cache-ofwtqwfm/wheels/1f/de/e2/d95d67b57b9a0639417cd656aecc8e5be88665ac5b63c2bd1b
Successfully built river
Installing collected packages: river
Successfully installed river-0.1.0


In [3]:
# Project folder on the mounted Google Drive.
data_path = '/content/gdrive/MyDrive/IDAO'

# Use the GPU when torch can see one, otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)

# Keyword arguments forwarded to the DataLoaders built in this notebook.
params = dict(batch_size=32, shuffle=True, num_workers=2)

# Learning rate.
lr = 1e-4


# Load data


In [None]:
%%time

# Copy the archive from the Drive folder into the working directory.
shutil.copyfile(f'{data_path}/raw_data/track_1.tar', 'track_1.tar')

# The context manager closes the archive even when extraction fails.
# NOTE(review): 'extract' is relative while the globs below read
# /content/extract, so this presumably assumes the working directory is
# /content -- confirm.
with tarfile.open('track_1.tar') as my_tar:
  my_tar.extractall('extract') # specify which folder to extract to

os.makedirs('/content/train', exist_ok=True)

# Sort every training image into /content/train/<energy><recoil>/.
# The energy is read from a fixed "_"-separated position of the FULL path
# (index 7 for ER files, index 8 for NR files), exactly as before.
for recoil, nrj_index in (("ER", 7), ("NR", 8)):
  for img in glob(f"/content/extract/idao_dataset/train/{recoil}/*.png"):
    nrj = img.split("_")[nrj_index]
    path = f'/content/train/{nrj}{recoil}'
    os.makedirs(path, exist_ok=True)
    shutil.move(img, f"{path}/{img.split('/')[-1]}")

# The local copy of the archive is no longer needed once extracted.
os.remove('track_1.tar')

### Split creation

In [None]:
def build_splits(train_splits, val_splits, root="/content/train"):
  """Collect the image paths of every cross-validation fold.

  Parameters
  ----------
  train_splits, val_splits : list of list of str
      For each fold, the names of the group folders (e.g. "6NR", "10ER")
      used for training and for validation. Both lists are zipped, so
      fold ``i`` pairs ``train_splits[i]`` with ``val_splits[i]``.
  root : str
      Folder that contains one sub-folder per group.

  Returns
  -------
  collections.defaultdict
      ``splits[i]["train"]`` and ``splits[i]["val"]`` are flat lists of the
      ``*.png`` paths found in the corresponding group folders.
  """
  splits = defaultdict(dict)
  for i, (train, val) in enumerate(zip(train_splits, val_splits)):
    splits[i]["train"] = [p for group in train for p in glob(f"{root}/{group}/*.png")]
    splits[i]["val"] = [p for group in val for p in glob(f"{root}/{group}/*.png")]
  return splits


# Each scheme holds out one NR group and one ER group per fold.
train_splits1 = [["6NR", "10ER", "20NR", "30ER"], ["1NR", "3ER", "20NR", "30ER"], ["1NR", "3ER", "6NR", "10ER"]]
val_splits1 = [["1NR", "3ER"], ["6NR", "10ER"], ["20NR", "30ER"]]
dict_splits1 = build_splits(train_splits1, val_splits1)

train_splits2 = [["6NR", "3ER", "20NR", "30ER"], ["1NR", "10ER", "20NR", "30ER"], ["1NR", "3ER", "6NR", "10ER"]]
val_splits2 = [["1NR", "10ER"], ["6NR", "3ER"], ["20NR", "30ER"]]
dict_splits2 = build_splits(train_splits2, val_splits2)

train_splits3 = [["6NR", "3ER", "20NR", "10ER"], ["1NR", "30ER", "20NR", "10ER"], ["1NR", "30ER", "6NR", "3ER"]]
val_splits3 = [["1NR", "30ER"], ["6NR", "3ER"], ["20NR", "10ER"]]
dict_splits3 = build_splits(train_splits3, val_splits3)

# This scheme only has two folds.
train_splits4 = [ ["1NR", "3ER", "20NR", "10ER"], ["1NR", "10ER", "6NR", "3ER"]]
val_splits4 = [["6NR", "30ER"], ["20NR", "30ER"]]
dict_splits4 = build_splits(train_splits4, val_splits4)


# Torch Dataset

In [None]:
class Dataset(torch.utils.data.Dataset):
  """Image dataset whose labels are encoded in each file's parent folder name.

  Every path is expected to live in a folder named ``<energy><recoil>``
  (e.g. ``6NR`` or ``30ER``). The classification label is 1 for ``ER`` and 0
  otherwise; the regression label is the energy as a float.

  Parameters
  ----------
  paths : sequence of str
      Image files to serve.
  train : bool
      When true, the random augmentations selected below are applied;
      otherwise images are only centre-cropped.
  rot : bool
      Apply a random rotation of up to +/-180 degrees (training only).
  flip : bool
      Apply random horizontal and vertical flips (training only).
  gblur : float
      Upper bound of the Gaussian-blur sigma range; blur is disabled when
      the value is not strictly positive (training only).
  """

  def __init__(self, paths, train, rot=True, flip=True, gblur=-1):
    """Store the file list and the augmentation switches."""
    self.paths = paths
    self.train = train
    self.rot = rot
    self.flip = flip
    self.gblur = gblur

  def __len__(self):
    """Return the total number of samples."""
    return len(self.paths)

  @staticmethod
  def _parse_labels(path):
    """Return ``(classif_label, regression_label)`` read from the parent folder of `path`.

    Using the parent folder name (instead of a fixed "/"-separated position)
    keeps the labels correct wherever the dataset root is located.
    """
    group = os.path.basename(os.path.dirname(path))
    classif_label = 1 if group[-2:] == "ER" else 0
    regression_label = float(group[:-2])
    return classif_label, regression_label

  def __getitem__(self, index):
    """Load one image, transform it and return ``(X, (class, energy))``."""
    # Select sample
    ID = self.paths[index]
    # Load data and convert it to a float32 tensor
    X = io.read_image(ID)
    X = transforms.ConvertImageDtype(torch.float32)(X)

    if self.train:
      # Rotate before cropping, then crop to the 256x256 centre.
      if self.rot:
        X = transforms.RandomRotation(180)(X)
      X = transforms.CenterCrop(256)(X)
      if self.flip:
        X = transforms.RandomHorizontalFlip(p=.5)(X)
        X = transforms.RandomVerticalFlip(p=.5)(X)
      if self.gblur > 0:
        X = transforms.GaussianBlur(5, sigma=(0.1, self.gblur))(X)
    else:
      X = transforms.CenterCrop(256)(X)

    y = self._parse_labels(ID)

    return X, y

# Train utils

### Fit / Validation

In [None]:
def fit_bin(model, criterion_class, metric, data_loader, device, optimizer):
  """Train `model` for one epoch on the binary-classification target.

  Parameters
  ----------
  model : torch.nn.Module
      Network whose output is compared to the class label.
  criterion_class : callable
      Loss applied to ``(prediction, target)``.
  metric : object
      Running metric exposing ``update(y_true, y_pred)``; updated in place.
  data_loader : torch.utils.data.DataLoader
      Yields ``X, (y_class, y_energy)`` batches; the energy is ignored here.
  device : str or torch.device
      Device on which the batches are placed.
  optimizer : torch.optim.Optimizer
      Optimizer stepping the model parameters.

  Returns
  -------
  (float, metric)
      Average loss per sample over the epoch and the updated metric.
  """
  running_loss = 0.
  n_seen = 0
  # A mean-reduced criterion returns a per-batch average: weight it by the
  # batch size so that dividing by the sample count gives a true average.
  mean_reduced = getattr(criterion_class, 'reduction', 'mean') == 'mean'

  model.train()
  for X, (y_class, _) in tqdm(data_loader, total=len(data_loader), position=0):
    X = X.to(device)
    y_class = y_class.reshape(-1, 1).type(torch.float).to(device)

    optimizer.zero_grad()
    pred_class = model(X)

    loss = criterion_class(pred_class, y_class)
    loss.backward()
    optimizer.step()

    batch_size = y_class.shape[0]
    running_loss += loss.item() * (batch_size if mean_reduced else 1)
    n_seen += batch_size

    y_true = y_class.detach().cpu().numpy().reshape(-1)
    y_score = pred_class.detach().cpu().numpy().reshape(-1)
    for y_t, y_p in zip(y_true, y_score):
      # The metric is updated in place; its return value is not relied upon
      # because recent river releases return None from `update`.
      metric.update(y_t, y_p)

  train_loss = running_loss / max(n_seen, 1)

  return train_loss, metric


def fit_bin_l2reg(model, criterion_class, metric, data_loader, device, optimizer):
  """Train `model` for one epoch with an added penalty on the parameter norms.

  The optimised loss is ``criterion + 0.01 * sum(||W||_2)`` over all model
  parameters, but only the classification part is reported so that the value
  is comparable with the validation loss.

  Parameters
  ----------
  model : torch.nn.Module
      Network whose output is compared to the class label.
  criterion_class : callable
      Loss applied to ``(prediction, target)``.
  metric : object
      Running metric exposing ``update(y_true, y_pred)``; updated in place.
  data_loader : torch.utils.data.DataLoader
      Yields ``X, (y_class, y_energy)`` batches; the energy is ignored here.
  device : str or torch.device
      Device on which the batches and the penalty accumulator are placed.
  optimizer : torch.optim.Optimizer
      Optimizer stepping the model parameters.

  Returns
  -------
  (float, metric)
      Average classification loss per sample and the updated metric.
  """
  running_class_loss = 0.
  n_seen = 0
  # A mean-reduced criterion returns a per-batch average: weight it by the
  # batch size so that dividing by the sample count gives a true average.
  mean_reduced = getattr(criterion_class, 'reduction', 'mean') == 'mean'

  model.train()
  for X, (y_class, _) in tqdm(data_loader, total=len(data_loader), position=0):
    X = X.to(device)
    y_class = y_class.reshape(-1, 1).type(torch.float).to(device)

    optimizer.zero_grad()
    pred_class = model(X)

    # Start the penalty from an explicit zero (torch.FloatTensor(1) would be
    # uninitialised memory) and add the 2-norm of every parameter tensor.
    l2_reg = torch.zeros(1, device=device)
    for W in model.parameters():
      l2_reg = l2_reg + W.norm(2)

    class_loss = criterion_class(pred_class, y_class)
    loss = class_loss + 0.01*l2_reg

    loss.backward()
    optimizer.step()

    batch_size = y_class.shape[0]
    # .item() detaches the value: accumulating the tensor itself would keep
    # the autograd graph alive and return a tensor instead of a float.
    running_class_loss += class_loss.item() * (batch_size if mean_reduced else 1)
    n_seen += batch_size

    y_true = y_class.detach().cpu().numpy().reshape(-1)
    y_score = pred_class.detach().cpu().numpy().reshape(-1)
    for y_t, y_p in zip(y_true, y_score):
      # The metric is updated in place; its return value is not relied upon
      # because recent river releases return None from `update`.
      metric.update(y_t, y_p)

  train_class_loss = running_class_loss / max(n_seen, 1)

  return train_class_loss, metric


def validate_bin(model, criterion_class, metric, data_loader, device):
  """Evaluate `model` on `data_loader` without updating its weights.

  Parameters
  ----------
  model : torch.nn.Module
      Network to evaluate (switched to eval mode).
  criterion_class : callable
      Loss applied to ``(prediction, target)``.
  metric : object
      Running metric exposing ``update(y_true, y_pred)``; updated in place.
  data_loader : torch.utils.data.DataLoader
      Yields ``X, (y_class, y_energy)`` batches; the energy is ignored here.
  device : str or torch.device
      Device on which the batches are placed.

  Returns
  -------
  (float, pandas.DataFrame, metric)
      Average loss per sample, a frame with the ``class_true`` and
      ``class_pred`` columns (one row per sample), and the updated metric.
  """
  running_loss = 0.
  n_seen = 0
  # A mean-reduced criterion returns a per-batch average: weight it by the
  # batch size so that dividing by the sample count gives a true average.
  mean_reduced = getattr(criterion_class, 'reduction', 'mean') == 'mean'

  list_class_true = []
  list_class_pred = []

  model.eval()
  with torch.no_grad():
    for X, (y_class, _) in data_loader:
      X = X.to(device)
      y_class = y_class.reshape(-1, 1).type(torch.float).to(device)

      pred_class = model(X)
      loss = criterion_class(pred_class, y_class)

      batch_size = y_class.shape[0]
      running_loss += loss.item() * (batch_size if mean_reduced else 1)
      n_seen += batch_size

      y_true = y_class.cpu().numpy().ravel().tolist()
      y_score = pred_class.cpu().numpy().ravel().tolist()
      list_class_true += y_true
      list_class_pred += y_score

      for y_t, y_p in zip(y_true, y_score):
        # The metric is updated in place; its return value is not relied
        # upon because recent river releases return None from `update`.
        metric.update(y_t, y_p)

  test_loss = running_loss / max(n_seen, 1)

  res_df = pd.DataFrame(
      {
      'class_true': list_class_true,
      'class_pred': list_class_pred
       })

  return test_loss, res_df, metric

### Model generation

In [None]:
def make_resnet18():
  """Build an untrained ResNet-18 taking 1-channel images and ending in a sigmoid.

  The first convolution is replaced by a 1-input-channel one and the final
  fully connected layer by a 512 -> 256 -> 1 head with dropout, ReLU and a
  sigmoid. The network is moved to the notebook-level `device` when CUDA is
  available.
  """
  net = models.resnet18(pretrained=False)
  # Stem accepting a single input channel.
  net.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  # Layers are instantiated in the same order as they appear in the head.
  head_layers = [
      nn.Linear(in_features=512, out_features=256, bias=True),
      nn.Dropout(p=0.5),
      nn.ReLU(),
      nn.Linear(in_features=256, out_features=1),
      nn.Sigmoid(),
  ]
  net.fc = nn.Sequential(*head_layers)
  if not torch.cuda.is_available():
    return net
  net.to(device)
  return net


def make_mobilenet_v2():
  """Build an untrained MobileNetV2 taking 1-channel images and ending in a sigmoid.

  `features[0]` is replaced by a 1-input-channel convolution and the
  classifier by a 1280 -> 128 -> 1 head with dropout, ReLU and a sigmoid.
  The network is moved to the notebook-level `device` when CUDA is available.
  """
  net = models.mobilenet_v2(pretrained=False)
  # NOTE(review): this overwrites the whole module stored at features[0],
  # not only a convolution inside it -- confirm that is intended.
  first_conv = nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  net.features[0] = first_conv
  head_layers = [
      nn.Linear(in_features=1280, out_features=128, bias=True),
      nn.Dropout(p=0.5),
      nn.ReLU(),
      nn.Linear(in_features=128, out_features=1),
      nn.Sigmoid(),
  ]
  net.classifier = nn.Sequential(*head_layers)
  use_gpu = torch.cuda.is_available()
  if use_gpu:
    net.to(device)
  return net


def make_squeezenet1_0():
  """Build an untrained SqueezeNet 1.0 taking 1-channel images and ending in a sigmoid.

  `features[0]` is replaced by a 1-input-channel convolution and the
  classifier by dropout, a 1x1 convolution, ReLU, 3x3 max pooling, a flatten
  and a linear layer followed by a sigmoid. The linear layer expects
  128*5*5 flattened features. The network is moved to the notebook-level
  `device` when CUDA is available.
  """
  net = models.squeezenet1_0(pretrained=False)
  first_conv = nn.Conv2d(1, 96, kernel_size=(7, 7), stride=(2, 2), bias=False)
  net.features[0] = first_conv
  # Layers are instantiated in the same order as they appear in the head.
  head_layers = [
      nn.Dropout(p=0.5),
      nn.Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1)),
      nn.ReLU(),
      nn.MaxPool2d(kernel_size=(3,3),stride=(3,3)),
      nn.Flatten(),
      nn.Linear(in_features=128*5*5,out_features=1),
      nn.Sigmoid(),
  ]
  net.classifier = nn.Sequential(*head_layers)
  if not torch.cuda.is_available():
    return net
  net.to(device)
  return net


def make_mobilenet_v3_small():
  """Build an untrained MobileNetV3-Small taking 1-channel images and ending in a sigmoid.

  `features[0]` is replaced by a 1-input-channel convolution and the
  classifier by a 576 -> 128 -> 1 head with Hardswish, dropout and a sigmoid.
  The network is moved to the notebook-level `device` when CUDA is available.
  """
  net = models.mobilenet_v3_small(pretrained=False)
  # NOTE(review): this overwrites the whole module stored at features[0],
  # not only a convolution inside it -- confirm that is intended.
  first_conv = nn.Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  net.features[0] = first_conv
  head_layers = [
      nn.Linear(in_features=576, out_features=128, bias=True),
      nn.Hardswish(),
      nn.Dropout(p=0.2),
      nn.Linear(in_features=128, out_features=1),
      nn.Sigmoid(),
  ]
  net.classifier = nn.Sequential(*head_layers)
  use_gpu = torch.cuda.is_available()
  if use_gpu:
    net.to(device)
  return net


def make_resnext50_32x4d():
  """Build an untrained ResNeXt-50 (32x4d) taking 1-channel images and ending in a sigmoid.

  The first convolution is replaced by a 1-input-channel one and the final
  fully connected layer by a 2048 -> 128 -> 1 head with dropout, ReLU and a
  sigmoid. The network is moved to the notebook-level `device` when CUDA is
  available.
  """
  net = models.resnext50_32x4d(pretrained=False)
  # Stem accepting a single input channel.
  net.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  # Layers are instantiated in the same order as they appear in the head.
  head_layers = [
      nn.Linear(in_features=2048, out_features=128, bias=True),
      nn.Dropout(p=0.5),
      nn.ReLU(),
      nn.Linear(in_features=128, out_features=1),
      nn.Sigmoid(),
  ]
  net.fc = nn.Sequential(*head_layers)
  if not torch.cuda.is_available():
    return net
  net.to(device)
  return net

### Cross Validation

In [2]:
def pred_class(model,data_loader):
  """Run `model` over `data_loader` and tabulate labels, energies and predictions.

  The model is put in eval mode and evaluated without gradients on the
  notebook-level `device`. Returns a DataFrame with one row per sample and
  the columns ``class_true``, ``keV`` and ``class_pred``.
  """
  model.eval()
  columns = {'class_true': [], 'keV': [], 'class_pred': []}

  def _flat(tensor):
    # Flatten a batch tensor into a plain Python list.
    return tensor.cpu().detach().numpy().ravel().tolist()

  with torch.no_grad():
    for batch, (labels, energies) in tqdm(data_loader,position=0):
      scores = model(batch.to(device))
      columns['class_true'].extend(_flat(labels))
      columns['class_pred'].extend(_flat(scores))
      columns['keV'].extend(_flat(energies))

  return pd.DataFrame(columns)


def train_model(model,train_loader,test_loader,fit,optimizer,criterion,path,nb_epochs=25,patience=10):
  """Train `model` with early stopping and save the best checkpoint, predictions and curves.

  Parameters
  ----------
  model : torch.nn.Module
      Network to train.
  train_loader, test_loader : torch.utils.data.DataLoader
      Training and validation batches.
  fit : callable
      One-epoch training function called with the keyword arguments
      ``model, criterion_class, data_loader, metric, device, optimizer`` and
      returning ``(loss, metric)``.
  optimizer : torch.optim.Optimizer
      Optimizer passed to `fit`.
  criterion : callable
      Loss passed to `fit` and to `validate_bin`.
  path : str
      Existing folder receiving ``best_model.pth``, ``df_res.csv`` and
      ``loss.png``.
  nb_epochs : int
      Maximum number of epochs.
  patience : int
      Training stops once the stagnation counter (which starts at 1 and
      grows by one per epoch without validation improvement) reaches this
      value.

  Notes
  -----
  Batches are placed on the notebook-level `device`, so the function also
  runs on CPU-only machines.
  """
  train_loss, train_auc = [], []
  test_loss, test_auc = [], []

  best_path = f'{path}/best_model.pth'
  checkpoint_saved = False
  best_loss = 10**4
  nb_stag = 1
  for i in range(nb_epochs):
    tmp_train_loss, tmp_train_auc = fit(
        model=model,
        criterion_class=criterion,
        data_loader=train_loader,
        metric=river.metrics.ROCAUC(),
        device=device,
        optimizer=optimizer
        )

    tmp_test_loss, _, tmp_test_auc = validate_bin(
        model=model,
        criterion_class=criterion,
        data_loader=test_loader,
        metric=river.metrics.ROCAUC(),
        device=device
        )

    # Plain floats keep the history plottable even if `fit` hands back a
    # (possibly GPU-resident) 0-dim tensor.
    tmp_train_loss = float(tmp_train_loss)
    tmp_test_loss = float(tmp_test_loss)

    if tmp_test_loss < best_loss :
      # New best validation loss: reset the counter and checkpoint the model.
      nb_stag = 1
      best_loss = tmp_test_loss
      torch.save(model, best_path)
      checkpoint_saved = True
      print(f'\nEpoch {i}/{nb_epochs}')
      print(f'Train : classif : {tmp_train_loss:.4f}; AUC : {tmp_train_auc}')
      print(f'Test  : classif : {tmp_test_loss:.4f}; AUC : {tmp_test_auc}')
    else :
      nb_stag += 1
    train_loss.append(tmp_train_loss)
    train_auc.append(tmp_train_auc.get())
    test_loss.append(tmp_test_loss)
    test_auc.append(tmp_test_auc.get())

    if nb_stag >= patience:
      break

  # Reload the best checkpoint only when this run actually wrote one (the
  # validation loss may never have improved, e.g. if it was NaN).
  if checkpoint_saved:
    model = torch.load(best_path)
  df_res = pred_class(model, test_loader)
  df_res.to_csv(f'{path}/df_res.csv', index=False)

  fig, (ax_loss, ax_auc) = plt.subplots(1, 2, figsize=(15,10))
  ax_loss.set_title('loss')
  ax_loss.plot(train_loss,label='train')
  ax_loss.plot(test_loss,label='test')
  ax_loss.legend()
  ax_auc.set_title('auc')
  ax_auc.plot(train_auc,label='train')
  ax_auc.plot(test_auc,label='test')
  ax_auc.legend()
  fig.savefig(f'{path}/loss.png')
  # Release the figure: one is created per cross-validation fold.
  plt.close(fig)


def cross_val(model_maker,dict_splits,fit,criterion,path,rot,flip,gblur):
  """Train one fresh model per fold of `dict_splits` and write the recap figures.

  Parameters
  ----------
  model_maker : callable
      Zero-argument factory returning a new model for each fold.
  dict_splits : mapping
      ``dict_splits[i]["train"]`` / ``["val"]`` hold the image paths of
      fold ``i``; folds are indexed ``0 .. len(dict_splits) - 1``.
  fit : callable
      One-epoch training function forwarded to `train_model`.
  criterion : callable
      Loss forwarded to `train_model`.
  path : str
      Output folder; one ``cv_<i>`` sub-folder is created per fold.
  rot, flip, gblur
      Augmentation settings forwarded to the training `Dataset`.
  """
  os.makedirs(path, exist_ok=True)
  # Validation batches do not need shuffling; a fixed order keeps the saved
  # prediction files reproducible.
  val_params = {**params, 'shuffle': False}
  n_splits = len(dict_splits)
  for n_split in range(n_splits):
    print(f'\nCV split {n_split+1}/{n_splits}\n')
    training_set = Dataset(dict_splits[n_split]["train"], train=True, rot=rot, flip=flip, gblur=gblur)
    training_generator = torch.utils.data.DataLoader(training_set, **params)
    validation_set = Dataset(dict_splits[n_split]["val"], train=False)
    validation_generator = torch.utils.data.DataLoader(validation_set, **val_params)
    model = model_maker()
    # Use the learning rate from the configuration cell instead of a
    # duplicated literal.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    cv_path = f'{path}/cv_{n_split}'
    os.makedirs(cv_path, exist_ok=True)
    train_model(
        model=model,
        train_loader=training_generator,
        test_loader=validation_generator,
        fit=fit,
        optimizer=optimizer,
        criterion=criterion,
        path=cv_path)
  recap_class_pred(path,True)
  recap_roc(path,True)


def recap_class_pred(path,savefig=True):
  """Merge the per-fold predictions found under `path` and plot them per class.

  Every ``<path>/cv_*/df_res.csv`` file is read (so any number of folds is
  supported), the rows are sorted by true class then energy, and the merged
  table is written to ``<path>/df_res.csv``.

  Parameters
  ----------
  path : str
      Folder containing the ``cv_<i>`` sub-folders.
  savefig : bool
      Save the figure to ``<path>/recap_class_pred.png`` when true,
      otherwise show it.

  Raises
  ------
  FileNotFoundError
      If no per-fold prediction file exists under `path`.
  """
  cv_files = sorted(glob(f'{path}/cv_*/df_res.csv'))
  if not cv_files:
    raise FileNotFoundError(f'No cv_*/df_res.csv file found in {path}')
  df_res = pd.concat([pd.read_csv(cv_file) for cv_file in cv_files], ignore_index=True)
  # Re-index AFTER sorting so the x axis of the plots follows the sorted
  # order (grouped by energy) rather than the original file order.
  df_res = df_res.sort_values(by=['class_true','keV']).reset_index(drop=True)
  display(df_res.head())

  fig, (ax_nr, ax_er) = plt.subplots(1, 2, figsize=(15,8))
  # Dataset labels ER as 1 and NR as 0; NR folders are 1/6/20 keV and ER
  # folders are 3/10/30 keV.
  ax_nr.set_title('Class 0 / NR / keV 1,6,20')
  ax_nr.plot(df_res.loc[df_res.class_true==0].class_pred,'.')
  ax_er.set_title('Class 1 / ER / keV 3,10,30')
  ax_er.plot(df_res.loc[df_res.class_true==1].class_pred,'.')
  if savefig:
    fig.savefig(f'{path}/recap_class_pred.png')
    plt.close(fig)
  else :
    plt.show()
  df_res.to_csv(f'{path}/df_res.csv')


def recap_roc(path,savefig=True):
  """Plot one ROC curve per energy pair from the merged predictions in `path`.

  Reads ``<path>/df_res.csv`` (columns ``class_true``, ``keV`` and
  ``class_pred``), draws a ROC curve for the samples whose energy is in
  each of the pairs (1, 3), (6, 10) and (20, 30), and encodes the overall
  ROC AUC (times 10000, truncated) in the name of the saved figure.

  Relies on ``roc_curve`` and ``auc`` from ``sklearn.metrics`` being
  imported at the top of the notebook.

  Parameters
  ----------
  path : str
      Folder containing ``df_res.csv``; also receives the figure.
  savefig : bool
      Save the figure when true, otherwise show it.
  """
  df_res = pd.read_csv(f'{path}/df_res.csv')
  roc_score = roc_auc_score(df_res.class_true, df_res.class_pred)
  energy_pairs = [[1,3],[6,10],[20,30]]
  fig, axes = plt.subplots(1, len(energy_pairs), figsize=(25,10))
  for ax, split in zip(axes, energy_pairs):
    tmp_res = df_res.loc[df_res.keV.isin(split)]
    fpr, tpr, threshold = roc_curve(tmp_res.class_true, tmp_res.class_pred)
    roc_auc = auc(fpr, tpr)
    ax.set_title(f'ROC split : {split}')
    ax.plot(fpr, tpr, 'b', label = 'AUC = %0.4f' % roc_auc)
    ax.legend(loc = 'lower right')
    # Diagonal of a random classifier, for reference.
    ax.plot([0, 1], [0, 1],'r--')
    ax.set_xlim([0, 1])
    ax.set_ylim([0, 1])
    ax.set_ylabel('True Positive Rate')
    ax.set_xlabel('False Positive Rate')
  if savefig:
    fig.savefig(f'{path}/recap_roc_{int(10000*roc_score)}.png')
    # Release the figure once it is on disk.
    plt.close(fig)
  else :
    plt.show()

# Train

In [None]:
# Binary cross-entropy applied to the sigmoid output of the network.
criterion_class = nn.BCELoss()

# ResNet-18 trained with the norm-penalised fit function on the second
# split scheme, with rotation, flips and Gaussian blur (sigma up to 2).
cross_val(
    path='/content/gdrive/MyDrive/IDAO/models/class/resnet_l2reg_gblur2_split2',
    model_maker=make_resnet18,
    dict_splits=dict_splits2,
    fit=fit_bin_l2reg,
    criterion=criterion_class,
    rot=True,
    flip=True,
    gblur=2,
)
