In [1]:
import os, sys, gc, copy, itertools, json
# Make the project-local `src` helpers (train_model, model_analysis,
# plot_images) importable. The original nested call also appended `None`
# (the return value of list.append) to sys.path; the guard also keeps
# re-runs of this cell from adding duplicate entries.
if "src" not in sys.path:
    sys.path.append("src")
from train_model import train_model
from model_analysis import model_analysis, torch_confusion_matrix
from plot_images import torch_to_PIL_single_image, ims_labels_to_grid, ims_preds_to_grid, ims_labels_preds_to_grid

import numpy as np
import pandas as pd
from tqdm import tqdm, tqdm_notebook

from functools import partial

from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import precision_recall_fscore_support

import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torchvision import transforms, utils, models
from torchvision.utils import make_grid

from tensorboardX import SummaryWriter

import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

from PIL import Image, ImageDraw, ImageFont
from IPython.core.display import display

from pytorch_learning_tools.data_providers.DataframeDataProvider import DataframeDataProvider
from pytorch_learning_tools.data_providers.DataframeDataset import DataframeDataset, DatasetSingleRGBImageToTarget, DatasetSingleRGBImageToTargetUniqueID, DatasetHPA
from pytorch_learning_tools.utils.dataframe_utils import filter_dataframe
from pytorch_learning_tools.utils.data_utils import classes_and_weights
from pytorch_learning_tools.utils.data_loading_utils import loadPILImages

  from ._conv import register_converters as _register_converters


In [2]:
# Run configuration: minibatch size for every dataloader split, and the CUDA
# device index that the model and class weights are moved to.
BATCH_SIZE = 128
GPU_ID = 2

In [3]:
# Root of the seq2loc data: holds the metadata CSV and the per-ENSG image
# folders under `hpa/`. Set SEQ2LOC_DATA_DIR to point elsewhere instead of
# editing this cell; the default is the original cluster location.
data_loc = os.environ.get('SEQ2LOC_DATA_DIR', '/root/aics/modeling/gregj/projects/seq2loc/data/')

In [4]:
# Load the HPA metadata table (presumably NaN rows were removed upstream, per the filename).
df = pd.read_csv(os.path.join(data_loc, 'hpa_data_noNaNs.csv'))

# Turn the relative image file names into absolute paths: <data_loc>/hpa/<ENSG>/<file>.
hpa_root = os.path.join(data_loc, 'hpa')
for channel_col in ('antibodyChannel', 'microtubuleChannel', 'nuclearChannel'):
    df[channel_col] = hpa_root + os.path.sep + df['ENSG'] + os.path.sep + df[channel_col]

# The classification target is the cell line, encoded as an int64 class index.
le = LabelEncoder()
df['targetNumeric'] = le.fit_transform(df['cellLine']).astype(np.int64)

# The row index doubles as the per-sample unique id.
df['uniqueID'] = df.index

# Show the class-name -> class-index mapping. LabelEncoder numbers its sorted
# classes_ 0..n-1, so enumerate yields the same pairs as le.transform(le.classes_).
label_map = {class_name: class_idx for class_idx, class_name in enumerate(le.classes_)}
print(json.dumps(label_map, indent = 2))

{
  "A-431": 0,
  "A549": 1,
  "AF22": 2,
  "ASC TERT1": 3,
  "BJ": 4,
  "CACO-2": 5,
  "HEK 293": 6,
  "HEL": 7,
  "HUVEC TERT2": 8,
  "HaCaT": 9,
  "HeLa": 10,
  "Hep G2": 11,
  "LHCN-M2": 12,
  "MCF7": 13,
  "NB-4": 14,
  "NIH 3T3": 15,
  "PC-3": 16,
  "REH": 17,
  "RH-30": 18,
  "RPTEC TERT1": 19,
  "RT4": 20,
  "SH-SY5Y": 21,
  "SK-MEL-30": 22,
  "SiHa": 23,
  "U-2 OS": 24,
  "U-251 MG": 25,
  "hTCEpi": 26
}


In [5]:
# Model input side length; the per-pixel mask below is sized to match.
IM_SIZE = 224
# Channel mask multiplied onto every input image: channels 0 and 2 pass through,
# channel 1 (the antibodyChannel in the image column order) is zeroed out.
mask = torch.ones(3, IM_SIZE, IM_SIZE)
mask[1].zero_()

def _make_split_dataset_kwargs():
    """Build the DataframeDataset kwargs for a single split.

    'tabularData' presumably maps each emitted minibatch key to a dataframe
    column (e.g. key 'cellLine' is filled from column 'targetNumeric').
    'imageData' builds one 3-channel 'inputImage' from the microtubule,
    antibody and nuclear path columns, in that channel order.
    A fresh transform pipeline is created per call, so each split gets its own objects.
    """
    image_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.CenterCrop(IM_SIZE),
        # NOTE(review): the IM_SIZE crop is upsampled to 256 and re-cropped to
        # IM_SIZE; with IM_SIZE=224 the net effect is a ~196 px centre crop
        # rescaled to 224. Confirm this zoom is intended.
        transforms.Resize(256),
        # The final crop must equal IM_SIZE because `mask` is (3, IM_SIZE, IM_SIZE);
        # a hard-coded size would break the multiplication if IM_SIZE changed.
        transforms.CenterCrop(IM_SIZE),
        transforms.ToTensor(),
        # Apply the channel mask (zeroes channel 1, the antibody channel).
        transforms.Lambda(lambda x: mask * x),
    ])
    return {
        'tabularData': {'Sequence': 'Sequence', 'cellLine': 'targetNumeric', 'uniqueID': 'uniqueID'},
        'imageData': {
            'inputImage': {
                'cols': ['microtubuleChannel', 'antibodyChannel', 'nuclearChannel'],
                # presumably opens each channel file as a grayscale ('L') PIL image
                'loader': partial(loadPILImages, mode='L'),
                'transform': image_transform,
            }
        },
    }


dataset_kwargs = {split: _make_split_dataset_kwargs() for split in ('train', 'test')}

# Shuffle and drop the ragged last batch only while training. The test split
# is iterated in a fixed order and in full, so evaluation covers every test
# sample instead of silently skipping up to BATCH_SIZE-1 of them.
dataloader_kwargs = {
    split: {
        'batch_size': BATCH_SIZE,
        'shuffle': split == 'train',
        'drop_last': split == 'train',
        'num_workers': 4,
        'pin_memory': True,
    }
    for split in ('train', 'test')
}

# Wrap the dataframe in a provider exposing per-split datasets/dataloaders,
# keyed by the 'uniqueID' column.
dp = DataframeDataProvider(
    df,
    uniqueID='uniqueID',
    dataset_kwargs=dataset_kwargs,
    dataloader_kwargs=dataloader_kwargs,
)

In [6]:
# i,mb = next(enumerate(dp.dataloaders['test']))
# ims_labels = [(im,label) for i,(im,label) in enumerate(zip(mb['inputImage'],mb['cellLine'])) if i<16]
# display(ims_labels_to_grid(ims_labels, ncol=2, not_boring_color=(0,0,0), boring_color=(0,0,0)))

In [7]:
# Per-class loss weights derived from the training split's 'targetNumeric'
# labels; moved to the model's GPU so the loss can use them directly.
classes, cpu_weights = classes_and_weights(dp, target_col='targetNumeric', split='train')
weights = cpu_weights.cuda(GPU_ID)
weights


 0.0001
 0.0009
 0.0071
 0.1072
 0.0026
 0.0010
 0.0006
 0.0370
 0.3572
 0.0024
 0.0007
 0.0009
 0.0893
 0.0005
 0.0765
 0.0006
 0.0007
 0.0134
 0.0008
 0.2679
 0.0012
 0.0012
 0.0018
 0.0015
 0.0001
 0.0001
 0.0268
[torch.cuda.FloatTensor of size (27,) (GPU 2)]

In [8]:
# ImageNet-pretrained torchvision backbone, looked up by name.
model_name = 'resnet18'
model_class = getattr(models, model_name)
model = model_class(pretrained=True)

# Replace the 1000-way ImageNet head with one output per cell-line class.
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, len(classes), bias=True)
model = model.cuda(GPU_ID)

In [9]:
# Number of training epochs; also encoded in the saved checkpoint's filename.
N_epochs = 10

In [None]:
# Train on the 'train' split and evaluate on 'test' each epoch, using the
# class-balancing weights computed above.
train_kwargs = dict(
    class_weights=weights,
    class_names=le.classes_,
    N_epochs=N_epochs,
    phases=('train', 'test'),
    learning_rate=1e-4,
    gpu_id=GPU_ID,
)
model = train_model(model, dp, **train_kwargs)

In [None]:
# Persist only the learned weights (state_dict). To reload, rebuild the
# architecture with getattr(models, model_name) plus the replaced `fc` head first.
save_dir = 'saved_models'
# torch.save does not create missing directories; create it up front so the
# save cannot fail after the (long) training run has already finished.
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, '{}_{}epochs.pt'.format(model_name, N_epochs))
torch.save(model.state_dict(), save_path)

In [None]:
# Collect true vs. predicted cell-line class indices for every split so the
# cells below can score them. `mito_labels` is a legacy name kept because
# those cells read it; the labels here are cell-line indices ('targetNumeric').
model.eval()

mito_labels = {phase: {'true_labels': [], 'pred_labels': []} for phase in dp.dataloaders.keys()}

# Inference only: disable autograd so no graph/activations are kept for backward.
with torch.no_grad():
    for phase in dp.dataloaders.keys():
        loader = dp.dataloaders[phase]
        for mb in tqdm_notebook(loader, total=len(loader), postfix={'phase': phase}):
            # Keys come from `dataset_kwargs`: 'inputImage' (imageData) and
            # 'cellLine' (tabularData, filled from 'targetNumeric'). This provider
            # does not emit 'image'/'target' keys.
            x = mb['inputImage']
            y = mb['cellLine']

            y_hat_pred = model(x.cuda(GPU_ID))
            _, y_hat = y_hat_pred.max(1)

            # view(-1) rather than squeeze(): a final batch of size 1 must still
            # produce a one-element list, not a 0-d array that list() rejects.
            mito_labels[phase]['true_labels'] += y.cpu().view(-1).numpy().tolist()
            mito_labels[phase]['pred_labels'] += y_hat.cpu().view(-1).numpy().tolist()

In [None]:
# Metrics on the training split (data the model was fit to), for comparison with test.
train_split_labels = mito_labels['train']
model_analysis(train_split_labels['true_labels'], train_split_labels['pred_labels'])

In [None]:
# Metrics on the held-out test split.
test_split_labels = mito_labels['test']
model_analysis(test_split_labels['true_labels'], test_split_labels['pred_labels'])