In [1]:
import torchvision
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import DataLoader
from project.dataset import Dataset, VALDODataset
from project.preprocessing import NiftiToTensorTransform, z_score_normalization
from project.utils import collate_fn, plot_mri_slice, plot_all_slices, plot_all_slices_from_array, collatev2
from torchvision.models import resnet50, resnet18

In [2]:
resnet_model = resnet50(pretrained=True)



In [3]:
print(resnet_model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [4]:
newmodel = torch.nn.Sequential(*(list(resnet_model.children())[:-1]))
print(newmodel)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)


In [5]:
import logging
logger = logging.getLogger('andy')
fh = logging.FileHandler('andy.log')
formatter = logging.Formatter(
    '%(asctime)s - %(levelname)s - %(message)s'
)

logger.setLevel(logging.DEBUG)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)

logger.addHandler(fh)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [6]:
ds = Dataset()

data = pd.read_csv('targets.csv')
data.shape

(7986, 3)

In [7]:
ch1 = ds.load_raw_mri(1)
data = data[data.mri.isin(ch1)]
data.shape

(385, 3)

In [8]:
transform = NiftiToTensorTransform(target_shape = (512, 512), rpn_mode=True)

cases = data.mri
masks = data.masks
target = data.target

In [9]:
dataset = VALDODataset(
    cases=cases,
    masks=masks,
    target=target,
    transform=transform,
    normalization=z_score_normalization,
)
dloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=10,
    collate_fn=collatev2,
)

In [10]:
class ResNetFeatureExtractor(nn.Module):
    def __init__(self):
        super(ResNetFeatureExtractor, self).__init__()
        
        self.resnet = resnet18(pretrained=True)

        # Remove the fc layer
        self.feature_extractor = nn.Sequential(*list(self.resnet.children())[:-1])
    
    def forward(self, image):
        output = []
        features = self.feature_extractor(image)
        
        return features

In [11]:
model = ResNetFeatureExtractor()
model = model.to(device)



In [12]:
for batch in dloader:
    for slices, masks, target, case in batch:
        shape = slices.shape
        slices = slices.view(shape[0], -1, shape[-2], shape[-1])
        print('Before conversion:', slices.shape)
        rgb_slice = slices.repeat(1, 3, 1, 1)
        print('After conversion:', rgb_slice.shape)
        rgb_slice = rgb_slice.float().to(device)
        with torch.no_grad():
            features = model(rgb_slice)
            print('Feature shape:', features.shape)
        break

Before conversion: torch.Size([35, 1, 256, 256])
After conversion: torch.Size([35, 3, 256, 256])


  return F.conv2d(input, weight, bias, self.stride,


Feature shape: torch.Size([35, 512, 1, 1])
Before conversion: torch.Size([35, 1, 256, 256])
After conversion: torch.Size([35, 3, 256, 256])
Feature shape: torch.Size([35, 512, 1, 1])
Before conversion: torch.Size([35, 1, 256, 256])
After conversion: torch.Size([35, 3, 256, 256])
Feature shape: torch.Size([35, 512, 1, 1])


KeyboardInterrupt: 

## `TODO:`
1. Check memory usage of each resnet