In [1]:
import time
import pathlib
import os
import glob

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split
from torchvision import datasets, transforms

import lightning as L
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.callbacks import ModelCheckpoint
import torchmetrics

import timm

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import cv2

#from datamodules import Cifar10DataModule, MnistDataModule
from plotting import show_failures, plot_loss_and_acc
from utils import load_image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class RecNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        ###self.CNN = timm.create_model('resnet50', pretrained=True, num_classes=0)
        self.cnn = timm.create_model('resnet18', pretrained=True, num_classes=0, in_chans=1)
        for param in self.cnn.parameters():
            param.requires_grad = False
        in_features = self.cnn(torch.randn(2, 1, 112, 112)).shape[1]
        #in_feature = self.cnn.fc.in_features
        
        self.rnn = nn.GRU(input_size=in_features, hidden_size=64, batch_first= True, bidirectional=False)
        
        self.fc = nn.Linear(16256, 32, bias=True)
        self.classifier = nn.Linear(32, 2, bias=True)

    def forward(self, x, org):
        # x shape: BxTxCxHxW
        batch_size, timesteps, C, H, W = x.size()
        c_in = x.view(batch_size * timesteps, C, H, W)
        print('reshape input', c_in.shape)
        
        mask = self.mask_layer(org)
        
        out = self.cnn(c_in)
        print('CNN ouput', out.shape)
        
        rnn_in = out.view(batch_size, timesteps, -1)
        print('reshaped rnn_in', rnn_in.shape)
        out, hd = self.rnn(rnn_in)
        
        #out =F.relu(self.RNN(out))
        print('RNN ouput', out.shape)
        #print('RNN hidden', hd.shape)
        
        out = out * mask
        print('mask ouput', out.shape)
        
        batch, timesteps, r_features = out.size() 
        #out = out.view(batch_size, timesteps * r_features)
        out = out.reshape(batch_size, timesteps * r_features)
        print('reshaped masked output', out.shape)
        
        out = F.relu(self.fc(out))
        print('fc ouput', out.shape)

        logits = self.classifier(out)
        print('classifier ouput', logits.shape)
        
        #output = F.softmax(logits, dim=1)
        #print('prb ouput', output.shape)
        #output = F.softmax(logits) #[prob 0, prob 1]

        #return output
        return logits

    def mask_layer(self, org):
        masks = []
        for i in org:
            dup = 254 - i
            mask_1 = torch.ones(i, 64)
            mask_0 = torch.zeros(dup, 64)
            mask = torch.cat((mask_1, mask_0), 0)
            masks.append(mask)
            #print(mask.shape)
        masks = torch.stack(masks).to(device='cuda')
        print('masks', masks.shape)
        return masks

In [3]:
class RSNAdataset(Dataset):
    def __init__(self, patient_path, paths, targets, n_slices, img_size, transform=None):
        #(self, './data/reduced_dataset/', t['xtrain'],t['ytrain'], 254, 112, transform)
        self.patient_path = patient_path
        self.paths = paths
        self.targets = targets
        self.n_slices = n_slices
        self.img_size = img_size
        self.transform = transform
          
    def __len__(self):
        #print(len(self.paths))
        return len(self.paths)
    
    def padding(self, paths):
        
        images=[load_image(path) for path in paths]
        org_size = len(images)

        #if len(images) != 0:
            
        dup_len = 254 - len(images)
        if org_size == 0:
            dup = torch.zeros(self.n_slices, 112, 112)
        else:
            dup = images[-1]
        for i in range(dup_len):
            images.append(dup)

        images = [torch.tensor(image, dtype=torch.float32) for image in images]

        #if len(images)==0:
        #    images = torch.zeros(self.n_slices, 112, 112)
        #else:
        images = torch.stack(images)

        return images, org_size
    
    '''def read_video(self, vid_paths):
        video = [load_image(path, (self.img_size, self.img_size)) for path in vid_paths]
        if self.transform:
            seed = random.randint(0,99999)
            for i in range(len(video)):
                random.seed(seed)
                video[i] = self.transform(image=video[i])["image"]
        
        video = [torch.tensor(frame, dtype=torch.float32) for frame in video]
        if len(video)==0:
            video = torch.zeros(self.n_frames, self.img_size, self.img_size)
        else:
            video = torch.stack(video) # T * C * H * W
        return video'''
    
    def __getitem__(self, index):
        _id = self.paths[index]
        patient_path = os.path.join(self.patient_path, f'{str(_id).zfill(5)}/')

        data = []
        org = []
        for t in ["FLAIR", "T1w", "T1wCE", "T2w"]:
            t_paths = sorted(
                glob.glob(os.path.join(patient_path, t, "*")), 
                key=lambda x: int(x[:-4].split("-")[-1]),
            )
            num_samples = self.n_slices
            ##if len(t_paths) < num_samples:
             #   in_frames_path = t_paths
            #else:
             #   in_frames_path = uniform_temporal_subsample(t_paths, num_samples)
            
            image, org_size = self.padding(t_paths)
            if image.shape[0] == 0:
                image = torch.zeros(num_samples, self.img_size, self.img_size)
            data.append(image)
            org.append(org_size)
            break
            
        data = torch.stack(data).transpose(0,1)
        #print(data.shape)
        #print('after transpose', data.shape)
        y = torch.tensor(self.targets[index], dtype=torch.float)
        return {"X": data.float(), "y": y}, org

In [4]:
folds_xtrain = np.load('./data/folds/xtrain.npy', allow_pickle=True)
folds_xtest = np.load('./data/folds/xtest.npy', allow_pickle=True)
folds_ytrain = np.load('./data/folds/ytrain.npy', allow_pickle=True)
folds_ytest = np.load('./data/folds/ytest.npy', allow_pickle=True)

xtrain = folds_xtrain[4]
ytrain = folds_ytrain[4]
xtest = folds_xtest[4]
ytest = folds_ytest[4]

print('-'*30)
print(f"Fold {'3'}")

------------------------------
Fold 3


In [5]:
train_retriever = RSNAdataset(
    'data/reduced_dataset/',
    xtrain,  
    ytrain,
    n_slices=254,
    img_size=112,
    transform=None
        )

In [6]:
dict, org = train_retriever[0]

In [7]:
batch = dict['X']
targets = dict['y']

In [8]:
batch.shape
#targets.shape


torch.Size([254, 1, 112, 112])

In [9]:
targets

tensor(1.)

In [10]:
org

[33]

In [6]:
train_loader = DataLoader(
            train_retriever,    
            batch_size=1,
            shuffle=True,
            num_workers=8,
        )

In [7]:
model = RecNet()
model.to(device='cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = F.cross_entropy
model.train()

RecNet(
  (cnn): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (drop_block): Identity()
        (act1): ReLU(inplace=True)
        (aa): Identity()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bi

In [None]:
X = batch.to(device='cuda')
X = X.unsqueeze(0)
print(X.shape)
y = targets.to(device='cuda')
print(y.shape)

#self.optimizer.zero_grad()
outputs = model(X, org).squeeze()
#outputs = outputs.squeeze()
print(outputs.shape)
print(outputs)

In [8]:
for i, batch in enumerate(train_loader, 1):
    X = batch[0]['X'].to(device='cuda')
    print('train_loader output',X.shape)
    y = batch[0]['y'].to(device='cuda')
    print('train_loader targets',y.shape)
    org = batch[1][0]
    print('train org', org)
    #count += 1
    
    #optimizer.zero_grad()
    outputs = model(X, org).squeeze(1)
    break
    #print('prob outputs', outputs.shape)
    
    #loss = criterion(outputs, y)
    #loss.backward()

    #train_loss.update(loss.detach().item())
    #train_score.update(targets, outputs.detach())
    
    #self.optimizer.step()
    
    #_loss, _score = train_loss.avg, train_score.avg
    #message = 'Train Step {}/{}, train_loss: {:.5f}, train_score: {:.5f}, train_f1: {:.5f}'
    #self.info_message(message, step, len(train_loader), _loss, _score, ff, end="\r")

    #f_score = ff_score.get_score()
    #return train_loss.avg, train_score.avg, f_score, int(time.time() - t)
    

train_loader output torch.Size([1, 254, 1, 112, 112])
train_loader targets torch.Size([1])
train org tensor([110])
reshape input torch.Size([254, 1, 112, 112])
masks torch.Size([1, 254, 64])
CNN ouput torch.Size([254, 512])
reshaped rnn_in torch.Size([1, 254, 512])
RNN ouput torch.Size([1, 254, 64])
mask ouput torch.Size([1, 254, 64])
reshaped masked output torch.Size([1, 16256])
fc ouput torch.Size([1, 32])
classifier ouput torch.Size([1, 2])


In [19]:
timm.list_models('*convnext*', pretrained=True)

['convnext_atto.d2_in1k',
 'convnext_atto_ols.a2_in1k',
 'convnext_base.clip_laion2b',
 'convnext_base.clip_laion2b_augreg',
 'convnext_base.clip_laion2b_augreg_ft_in1k',
 'convnext_base.clip_laion2b_augreg_ft_in12k',
 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k',
 'convnext_base.clip_laion2b_augreg_ft_in12k_in1k_384',
 'convnext_base.clip_laiona',
 'convnext_base.clip_laiona_320',
 'convnext_base.clip_laiona_augreg_320',
 'convnext_base.clip_laiona_augreg_ft_in1k_384',
 'convnext_base.fb_in1k',
 'convnext_base.fb_in22k',
 'convnext_base.fb_in22k_ft_in1k',
 'convnext_base.fb_in22k_ft_in1k_384',
 'convnext_femto.d1_in1k',
 'convnext_femto_ols.d1_in1k',
 'convnext_large.fb_in1k',
 'convnext_large.fb_in22k',
 'convnext_large.fb_in22k_ft_in1k',
 'convnext_large.fb_in22k_ft_in1k_384',
 'convnext_large_mlp.clip_laion2b_augreg',
 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k',
 'convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384',
 'convnext_large_mlp.clip_laion2b_augreg_ft_in12k_384',

In [None]:
res

In [26]:
res = timm.create_model('resnet50', pretrained=True, num_classes=0, in_chans=1)
#m = res(torch.randn(2, 3, 224, 224))
res.reset_classifier(0)
o = res(torch.randn(2, 1, 112, 112))
print(f'Pooled shape: {o.shape[1]}')
print(res.fc)
in_features = res(torch.randn(2, 1, 112, 112)).shape[1]
print(in_features)

Pooled shape: 2048
Identity()
2048


In [None]:
mask_1 = torch.ones(200, 64)
mask_0 = torch.zeros(50, 64)

mask = torch.cat((mask_1, mask_0), 0)
mask.shape

In [20]:
out = torch.ones(254, 1, 112, 112)
cpy = out[-1]
cpy.shape

torch.Size([1, 112, 112])

In [6]:
list = os.listdir('data/reduced_dataset/00123/FLAIR')
dup_len = 254 - len(list)
dup_len

for i in range(dup_len):
    

60

In [None]:
list = list(pathlib.Path('data/reduced_dataset/00002/FLAIR/').rglob("*.png"))
list

In [None]:
for path in pathlib.Path('data/reduced_dataset/00002/FLAIR/').rglob("*.png"):
    print(path.name)

In [9]:
p = []
for path in os.listdir('data/reduced_dataset/00002/FLAIR/'):
    p.append(os.path.join('data/reduced_dataset/00002/FLAIR/', path))

In [12]:
p[0]

'data/reduced_dataset/00002/FLAIR/Image-460.png'

In [18]:
#images = [load_image(path) for i in p]
images=[]
for i in p:
    #print(i)
    images.append(load_image(i))
print(len(images))

list = os.listdir('data/reduced_dataset/00005/FLAIR')
dup_len = 254 - len(images)
dup_len
print(dup_len)

dup = images[-1]
for i in range(dup_len):
    images.append(dup)
print(len(images))

47
207
254


In [21]:
images = [torch.tensor(frame, dtype=torch.float32) for frame in images]

In [23]:
images = torch.stack(images)

In [24]:
images.shape

torch.Size([254, 112, 112])

In [None]:
t_paths = sorted(
                glob.glob(os.path.join('data/reduced_dataset/00005/', 'FLAIR', "*")), 
                key=lambda x: int(x[:-4].split("-")[-1]),
            )
t_paths

In [17]:
for i in xtrain:
    if i == 123:
        print(i)

123


In [21]:
count

235

In [32]:
masks = []
org = [200, 180,170,210]

for i in org:
    dup = 254 - i
    mask_1 = torch.ones(i, 64)
    mask_0 = torch.zeros(dup, 64)
    mask = torch.cat((mask_1, mask_0), 0)
    masks.append(mask)
    print(mask.shape)
m = torch.stack(masks)
print(m.shape)

torch.Size([254, 64])
torch.Size([254, 64])
torch.Size([254, 64])
torch.Size([254, 64])
torch.Size([4, 254, 64])


In [27]:
masks[0]

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [28]:
masks[1]

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [29]:
masks[2]

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [30]:
masks[3]

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [31]:
m

tensor([[1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        [1., 1., 1.,  ..., 1., 1., 1.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [None]:
%reload_ext watermark
%watermark -a 'Karanjot Vendal' -v -p torch --iversion