## Data

In [60]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from PIL import Image

import torch 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split

# Seed PyTorch's global random number generator so torch-side randomness is
# repeatable. NumPy keeps its own separate generator, so this line alone does
# not make np.random.* calls reproducible.
torch.manual_seed(1111)

<torch._C.Generator at 0x7fb2e82931b0>

In [61]:
def crop(X, start=106, size=300):
    """Cut the same square window out of every 2-D slice of a 4-D stack.

    Parameters
    ----------
    X : array-like of shape (n_items, n_slices, height, width)
        Image stack to crop.
    start : int, default 106
        Index of the first row and first column kept.
    size : int, default 300
        Side length of the square window; rows and columns
        ``start .. start + size - 1`` are kept.

    Returns
    -------
    numpy.ndarray
        A new float64 array of shape (n_items, n_slices, size, size).

    Raises
    ------
    ValueError
        If ``X`` is not 4-dimensional or the window does not fit inside
        the last two axes.
    """
    arr = np.asarray(X)
    if arr.ndim != 4:
        raise ValueError(
            f'crop expects a 4-D array (items, slices, height, width), got ndim={arr.ndim}'
        )

    stop = start + size
    if start < 0 or size < 0 or arr.shape[2] < stop or arr.shape[3] < stop:
        raise ValueError(
            f'window [{start}:{stop}] does not fit into slices of shape {arr.shape[2:]}'
        )

    # One vectorised slice replaces the per-row Python loop; astype() always
    # copies, so the result is float64 and independent of the input buffer.
    return arr[:, :, start:stop, start:stop].astype(np.float64)

# Seed NumPy's global RNG: the hold-out split below is drawn with
# np.random.choice, which torch.manual_seed() does not influence. Without
# this the frozen test set changes on every run.
np.random.seed(1111)

N_TEST_PER_CLASS = 5  # patients per class held out as the frozen test set

# MRIs: 4-D stack (patient, slice, height, width), cropped by crop()
X = np.load('small_data.npy')
X = crop(X)
n_slices = X.shape[1]   # slices per patient
img_shape = X.shape[2:]  # (height, width) of one cropped slice


# labels
y = pd.read_csv('all_target.csv')
y.columns = ['y']

# divide into 2 classes: no cut / cut - (0 / 1)
y = np.where(y.y <= 3, 0, 1)


# Find all ill and healthy indices
ill_inds = np.argwhere(y == 1).flatten()
hea_inds = np.argwhere(y == 0).flatten()

# Choose N_TEST_PER_CLASS patients from each group for further testing
ill_test_inds = np.random.choice(ill_inds, N_TEST_PER_CLASS, replace=False)
hea_test_inds = np.random.choice(hea_inds, N_TEST_PER_CLASS, replace=False)

test_inds = [*ill_test_inds, *hea_test_inds]
train_inds = [i for i in range(len(y)) if i not in test_inds]

X_frozen = X[test_inds]
y_frozen = y[test_inds]

# Save just original balanced: every remaining ill patient plus an equally
# sized random sample of the remaining healthy patients
real_ill_train = np.array([i for i in ill_inds if i not in ill_test_inds])
real_hea_train = np.random.choice([i for i in hea_inds if i not in hea_test_inds],
                                  len(real_ill_train), replace=False)

# Interleave ill / healthy patients: ill, healthy, ill, healthy, ...
real_inds = []
for ill_i, hea_i in zip(real_ill_train, real_hea_train):
    real_inds.append(ill_i)
    real_inds.append(hea_i)

X_real = X[real_inds]
y_real = y[real_inds]

# leave the rest and use as the main data
X = X[train_inds]
y = y[train_inds]

# Transformation: every slice becomes its own sample of shape (1, 1, H, W)
# and inherits the label of the patient it came from. The sample count is
# inferred (-1) instead of being hard-coded.
X = torch.from_numpy(X).to(torch.float32).reshape((-1, 1, 1, *img_shape))
y = [[i] * n_slices for i in y]
y = torch.tensor(y).to(torch.float32).reshape(-1, 1)

X_frozen = torch.from_numpy(X_frozen).to(torch.float32).reshape((-1, 1, 1, *img_shape))
y_frozen = [[i] * n_slices for i in y_frozen]
y_frozen = torch.tensor(y_frozen).to(torch.float32).reshape(-1, 1)

X_real = torch.from_numpy(X_real).to(torch.float32).reshape((-1, 1, 1, *img_shape))
y_real = [[i] * n_slices for i in y_real]
# y_real stays an integer tensor: it is passed to nn.CrossEntropyLoss as class
# indices, which must be of dtype long.
y_real = torch.tensor(y_real).to(torch.long).reshape(-1, 1)

# Reorder the balanced set slice-major: slice 0 of every patient first, then
# slice 1 of every patient, and so on.
new_inds = []
for i in range(n_slices):
    for j in range(i, len(X_real), n_slices):
        new_inds.append(j)

X_real = X_real[new_inds]
y_real = y_real[new_inds]

# Each sample must have exactly one label.
assert len(X) == len(y)
assert len(X_frozen) == len(y_frozen)
assert len(X_real) == len(y_real)

print(f'X shape = {X.shape}')
print(f'X_real shape = {X_real.shape}')
print(f'X_frozen shape = {X_frozen.shape}')

X shape = torch.Size([710, 1, 1, 300, 300])
X_real shape = torch.Size([160, 1, 1, 300, 300])
X_frozen shape = torch.Size([100, 1, 1, 300, 300])


## Balance

## Model

In [68]:
class deep_simple(nn.Module):
    """Small CNN producing two class logits for single-channel square images.

    Architecture: three 5x5 convolutions without padding (1 -> 6 -> 16 -> 28
    channels, ReLU after the first two), a 2x2 max-pool, then five linear
    layers (flattened features -> 1000 -> 500 -> 100 -> 10 -> 2) with a ReLU
    after the first one only.

    NOTE(review): the last four linear layers have no activation between
    them, so together they act as a single affine map. The layout is kept
    unchanged so that previously saved state_dicts still load.

    Parameters
    ----------
    batch_size : int
        Leading dimension used when the convolutional output is flattened
        in ``forward``.
    input_size : int, default 300
        Side length of the (square) input images. It determines the number
        of input features of the first linear layer.

    Raises
    ------
    ValueError
        If ``input_size`` is too small to survive the convolutions and the
        pooling step.
    """

    def __init__(self, batch_size, input_size=300):
        super().__init__()

        self.batch_size = batch_size

        # Each unpadded 5x5 convolution shortens a side by 4 pixels; the 2x2
        # max-pool with stride 2 then halves it (rounding down).
        pooled_side = (input_size - 3 * 4) // 2
        if pooled_side < 1:
            raise ValueError(
                f'input_size={input_size} is too small; it must be at least 14'
            )
        # Flattened size of the conv output for one image (28*144*144 for the
        # default 300x300 input).
        self.n_features = 28 * pooled_side * pooled_side

        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
            nn.ReLU(),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=28, kernel_size=5),
            nn.MaxPool2d(2, 2)
        )

        self.class_layers = nn.Sequential(
            nn.Linear(self.n_features, 1000),
            nn.ReLU(),
            nn.Linear(1000, 500),
            nn.Linear(500, 100),
            nn.Linear(100, 10),
            nn.Linear(10, 2)
        )

    def forward(self, x):
        """Return logits of shape (batch_size, k, 2) for input (N, 1, H, W).

        ``k`` is ``N / batch_size``; it is 1 when the input batch matches the
        ``batch_size`` given at construction.
        """
        x = self.conv_layers(x)
        # Flatten each image's feature maps before the linear classifier.
        x = self.class_layers(x.reshape((self.batch_size, -1, self.n_features)))

        return x

In [66]:
# Smoke test: run a single sample through a freshly initialised network to
# check that the layer shapes fit together. The cell shows the untrained
# model's output for that sample.
first_sample = X_real[0]
ds = deep_simple(batch_size=1)
ds(first_sample)

torch.Size([1, 28, 144, 144])


tensor([[[ 0.0750, -0.2078]]], grad_fn=<AddBackward0>)

In [69]:
# # Load a pretrained model
ds = deep_simple(1)
# ds.load_state_dict(torch.load('deep_simple_dict.pth'))
# ds.eval()
# print('Model is loaded')

# Real data training and testing
# training
n_epoch = 20
LOG_EVERY = 10  # print the mean training loss after this many samples
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(ds.parameters(), lr=0.01)

best_acc = 0.5  # a checkpoint is only written once accuracy beats this
mod_cnt = 1     # running number used in the checkpoint file name
for epoch in range(n_epoch):

    # --- training: one optimizer step per sample ---
    ds.train()
    i_cnt = 0
    loss_sum = 0.0
    for i, sample in enumerate(X_real):
        preds = ds(sample).reshape(-1, 2)
        lable = y_real[i]

        loss = criterion(preds, lable)
        optim.zero_grad()
        loss.backward()
        optim.step()

        i_cnt += 1
        # Accumulate a plain Python float. Adding the loss tensor itself
        # would keep autograd history attached to the running sum.
        loss_sum += loss.item()

        if i_cnt == LOG_EVERY:
            print(f'Loss: {round(loss_sum / LOG_EVERY, 4)}')
            i_cnt = 0
            loss_sum = 0.0

    # --- evaluation on the frozen hold-out set ---
    ds.eval()
    with torch.no_grad():
        tp = 0
        fp = 0
        fn = 0
        tn = 0
        for j, samp in enumerate(X_frozen):
            # Index of the larger of the two logits is the predicted class.
            preds = torch.argmax(ds(samp).reshape(-1, 2)).item()
            lable = y_frozen[j].item()

            if preds == lable:
                if lable == 1:
                    tp += 1
                else:
                    tn += 1

            else:
                if lable == 1:
                    fn += 1
                else:
                    fp += 1

        acc = (tp + tn) / len(y_frozen)
        print(f'Epochs passed: {epoch}, \t Frozen Accuracy: {acc}')
        # Confusion matrix layout: [[TP, FP], [FN, TN]]
        print(f'{np.array([[tp, fp], [fn, tn]])}')

        # NOTE(review): checkpoints are selected on the same frozen set that
        # is reported as test accuracy, so that number is optimistic. A
        # separate validation split would give an unbiased estimate.
        if acc > best_acc:
            best_acc = acc
            torch.save(ds.state_dict(), f'ds_{mod_cnt}_dict.pth')
            mod_cnt += 1

Loss: 7004396.0
Loss: 22410.0484
Loss: 30158.025
Loss: 17417.0297
Loss: 3664.8367
Loss: 12491.407
Loss: 2538.0125
Loss: 923.3652
Loss: 1139.7498
Loss: 798.8555
Loss: 368.672
Loss: 178.0038
Loss: 192.7193
Loss: 247.5324
Loss: 175.038
Loss: 192.4241
Epochs passed: 0, 	 Frozen Accuracy: 0.53
[[26 23]
 [24 27]]
Loss: 123.8772
Loss: 244.8015
Loss: 30.869
Loss: 261.3204
Loss: 49.4534
Loss: 0.0
Loss: 99.4896
Loss: 50.364
Loss: 183.672
Loss: 171.8393
Loss: 148.2712
Loss: 124.401
Loss: 148.0999
Loss: 70.2467
Loss: 99.3141
Loss: 146.6992
Epochs passed: 1, 	 Frozen Accuracy: 0.56
[[35 29]
 [15 21]]
Loss: 65.18
Loss: 35.8452
Loss: 45.734
Loss: 36.9794
Loss: 24.9572
Loss: 33.3329
Loss: 0.0
Loss: 0.3978
Loss: 1.9328
Loss: 0.0
Loss: 6.8776
Loss: 25.3305
Loss: 12.3646
Loss: 43.478
Loss: 16.2007
Loss: 33.2071
Epochs passed: 2, 	 Frozen Accuracy: 0.58
[[23 15]
 [27 35]]
Loss: 49.9366
Loss: 8.7303
Loss: 45.8703
Loss: 60.2665
Loss: 21.3747
Loss: 15.9647
Loss: 47.6371
Loss: 141.3612
Loss: 0.0
Loss: 0.0
Los

KeyboardInterrupt: 

## Training & Testing