In [None]:
!pip install skorch

In [None]:
import glob
import os
import pickle
import urllib.request
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import tqdm
from PIL import Image
from sklearn.metrics import mean_absolute_error, r2_score, accuracy_score, roc_auc_score
from sklearn.model_selection import train_test_split
from torchvision import transforms

## Using Pre-trained CNNs

<img src='https://i.stack.imgur.com/gI4zT.png'>

## Example: Cell Counting

In [None]:
# Dataset: BBBC005 (https://bbbc.broadinstitute.org/BBBC005)
# Filename pattern: SIMCEPImages_<well>_C<cells>_F<blur>_s<sample>_w<stain>.TIF
#
# Download and extract only when the data is missing, so re-running the notebook
# neither re-fetches the archive nor hits `unzip`'s interactive overwrite prompt.
# Errors raise instead of being silently ignored (os.system's return code was unchecked).
# The archive is first written under a temporary name so an interrupted download does
# not leave a truncated .zip behind for later runs to try to extract.
DATASET_URL = 'https://data.broadinstitute.org/bbbc/BBBC005/BBBC005_v1_images.zip'
DATASET_ZIP = 'BBBC005_v1_images.zip'
DATASET_DIR = 'BBBC005_v1_images'

if not os.path.isdir(DATASET_DIR):
  if not os.path.exists(DATASET_ZIP):
    urllib.request.urlretrieve(DATASET_URL, DATASET_ZIP + '.part')
    os.replace(DATASET_ZIP + '.part', DATASET_ZIP)
  with zipfile.ZipFile(DATASET_ZIP) as archive:
    archive.extractall()

### View a representative image

In [None]:
# Show one example image at the 299x299 size the loader resizes to, then show the
# shape of the tensor ToTensor produces from it. The path is relative to the working
# directory (the same directory the dataset loader globs) rather than hard-coding
# Colab's /content, and the file is closed once the image has been converted.
sample_path = os.path.join('BBBC005_v1_images', 'SIMCEPImages_A07_C27_F1_s11_w2.TIF')
with Image.open(sample_path) as raw_img:
  sample_img = transforms.Resize((299, 299))(raw_img.convert('RGB'))
display(sample_img)
img = transforms.ToTensor()(sample_img)
img.shape

### Load dataset

In [None]:
def load_cell_dataset():
  X = []
  y = []
  imgs = sorted(glob.glob("BBBC005_v1_images/*.TIF"))[:2500]
  preprocess = transforms.Compose([transforms.Resize((299, 299)),
                                   transforms.RandomHorizontalFlip(),
                                   transforms.RandomVerticalFlip(),
                                   transforms.ToTensor(),
                                   # normalization used on training resnet-50 data
                                   transforms.Normalize(mean=[0.7137, 0.6628, 0.6519], \
                                                        std=[0.2970, 0.3017, 0.2979])])
  for i in tqdm.tqdm(imgs):
    # Convert image from B&W .TIF file to RGB image
    img = Image.open(i).convert('RGB')

    # Apply preprocessing
    img = preprocess(img)
    X.append(img)

    # Determine number of cells in the image from the filename
    name = os.path.basename(i)
    ncells = float(name.split("_")[2][1:])
    y.append(torch.tensor([ncells]))

  X = torch.stack(X, dim=0).float()
  y = torch.stack(y, dim=0).float()
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)
  return X_train, X_test, y_train, y_test

# Build the image tensors and cell-count targets, holding out 10% as the test split.
(X_train, X_test,
 y_train, y_test) = load_cell_dataset()

### Initialize pre-trained model

In [None]:
# Load the ImageNet-pretrained ResNet-50 and display its module structure.
# torchvision >= 0.13 selects pretrained weights through the `weights=` enum, and the
# `pretrained=True` flag is deprecated there. IMAGENET1K_V1 is the checkpoint that
# flag used to load. Fall back to the flag on older torchvision, which lacks the enum.
if hasattr(torchvision.models, 'ResNet50_Weights'):
  model = torchvision.models.resnet50(
      weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
else:
  model = torchvision.models.resnet50(pretrained=True)
model

### Replace the final layer of the pre-trained model with a regression head

In [None]:
# Freeze the pretrained backbone so that backward() does not compute (and accumulate)
# gradients for weights that are never optimized. Then replace the classification
# layer with a single-output regression head, sized from the existing layer's
# in_features rather than a hard-coded 2048. The new nn.Linear is created after
# freezing, so its parameters stay trainable.
for param in model.parameters():
  param.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, 1)

### Standardize the regression targets

In [None]:
# Z-score the cell-count targets using statistics from the training split only, so
# nothing about the test split leaks into the scaling. y_mean / y_std are kept so
# predictions can be mapped back to cell counts.
# NOTE: this cell is not idempotent. Running it twice re-standardizes the already
# scaled targets and overwrites y_mean / y_std with ~0 / ~1, so reload the data first.
y_mean, y_std = y_train.mean(), y_train.std()
y_train, y_test = (y_train - y_mean) / y_std, (y_test - y_mean) / y_std

### Fit the model

In [None]:
from skorch import NeuralNetRegressor


def optim(pgroups, **kwargs):
  """Return an Adam optimizer over the regression head ``model.fc`` only.

  skorch calls the ``optimizer`` with the net's parameter groups (``pgroups``)
  and the optimizer keyword arguments, including ``lr``. ``pgroups`` is
  deliberately ignored so that only the head's parameters are updated.

  NOTE(review): this closes over the global ``model`` rather than the module
  skorch is training. It is only correct while ``regr`` wraps that same object.
  """
  return torch.optim.Adam(model.fc.parameters(), **kwargs)


# Fall back to the CPU when no GPU is available (slow, but the cell still runs).
regr = NeuralNetRegressor(
    model,
    batch_size=16,
    max_epochs=10,
    lr=2e-3,
    optimizer=optim,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
regr.fit(X_train, y_train)

### Evaluate the model

In [None]:
# Undo the target standardization on both predictions and targets so the metrics are
# in units of cells. The R^2 value is stored as `r2`, not `r2_score`, so the sklearn
# function of that name (imported at the top of the notebook) is not shadowed by a float.
preds = regr.predict(X_test) * y_std.item() + y_mean.item()
targets = y_test.numpy() * y_std.item() + y_mean.item()
r2 = r2_score(targets, preds)
mae_score = mean_absolute_error(targets, preds)
print(f'R^2: {r2:.4f}  MAE: {mae_score:.3f} cells')

In [None]:
# Predicted vs. true cell counts on the test split. Points on the dashed y = x line
# are perfect predictions.
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(targets, preds, alpha=0.6)
lims = [min(targets.min(), preds.min()), max(targets.max(), preds.max())]
ax.plot(lims, lims, 'k--', linewidth=1, label='y = x')
ax.set(xlabel='True cell count', ylabel='Predicted cell count',
       title='ResNet-50 regression: test-set predictions')
ax.legend();