In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two stacked stages of 3x3 convolution -> BatchNorm -> ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second stage.
        mid_channels: channels between the two stages; falls back to
            ``out_channels`` when omitted (or otherwise falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        # Build both stages in order so the Sequential indices (0..5) and
        # therefore the state_dict keys stay the same.
        stages = []
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages += [
                # padding=1 with a 3x3 kernel keeps the spatial size unchanged;
                # the bias is dropped because BatchNorm follows immediately.
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.double_conv(x)
        return features


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling followed by a DoubleConv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Kept as a two-element Sequential so parameter names are unchanged.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        downsampled = self.maxpool_conv(x)
        return downsampled


class Up(nn.Module):
    """Decoder stage: upsample, pad to the skip connection, concatenate, DoubleConv.

    Args:
        in_channels: channels of the concatenated tensor fed to the DoubleConv.
        out_channels: channels produced by the DoubleConv.
        bilinear: use fixed bilinear upsampling when True, otherwise a
            learned transposed convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        half_channels = in_channels // 2
        if not bilinear:
            # The transposed convolution halves the channel count itself.
            self.up = nn.ConvTranspose2d(in_channels, half_channels, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            # Bilinear upsampling leaves channels untouched, so the normal
            # convolutions reduce them via the narrower mid_channels.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s height/width and fuse both."""
        upsampled = self.up(x1)

        # Tensors are NCHW: dim 2 is height, dim 3 is width.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        pad_top = gap_h // 2
        pad_left = gap_w // 2

        # F.pad takes the last dimension first: [left, right, top, bottom].
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        upsampled = F.pad(upsampled, [pad_left, gap_w - pad_left,
                                      pad_top, gap_h - pad_top])

        # Skip-connection channels come first, then the upsampled ones.
        fused = torch.cat([x2, upsampled], dim=1)
        return self.conv(fused)


class OutConv(nn.Module):
    """1x1 convolution that maps ``in_channels`` feature maps to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected


class ConvToFC(nn.Module):
    """Collapse a feature map to a vector: (in_dims x 1) conv -> ReLU -> Linear.

    The convolution kernel spans ``in_dims`` rows and reduces all channels to
    one, and the linear layer consumes ``in_dims`` values along the last
    dimension. An input of shape (N, in_channels, in_dims, in_dims) therefore
    yields an output of shape (N, 1, 1, out_dims).

    Args:
        in_channels: channels of the incoming feature map.
        in_dims: kernel height of the convolution and input width of the
            linear layer.
        out_dims: number of output values.
    """

    def __init__(self, in_channels, in_dims, out_dims):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 1, kernel_size=(in_dims, 1))
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(in_dims, out_dims)

    def forward(self, x):
        collapsed = self.conv(x)
        activated = self.relu(collapsed)
        return self.fc(activated)

class UNet(nn.Module):
    """U-Net encoder/decoder with a ConvToFC head that emits one value per class.

    The decoder output keeps the input's spatial size, and the ConvToFC head
    expects that size to be ``in_dims`` x ``in_dims``. An input of shape
    (N, n_channels, in_dims, in_dims) yields (N, 1, 1, n_classes).

    Args:
        n_channels: channels of the input image.
        n_classes: number of values produced by the head.
        bilinear: use bilinear upsampling in the decoder instead of
            transposed convolutions.
        in_dims: height/width of the input images; sizes the ConvToFC head.
            Defaults to 512, the value that used to be hard-coded.
    """

    # Class-level default so the flag exists even on instances that were
    # created without running __init__ (e.g. unpickled older models).
    _use_checkpointing = False

    def __init__(self, n_channels, n_classes, bilinear=False, in_dims=512):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.in_dims = in_dims

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # With bilinear upsampling the Up stages do not halve channels
        # themselves, so the surrounding widths are halved to compensate.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = ConvToFC(64, in_dims=in_dims, out_dims=n_classes)

    def _run(self, stage, *inputs):
        """Call ``stage(*inputs)``, through activation checkpointing if enabled."""
        if self._use_checkpointing:
            from torch.utils.checkpoint import checkpoint
            # Non-reentrant mode still produces parameter gradients when the
            # inputs themselves do not require grad (true for the first stage).
            return checkpoint(stage, *inputs, use_reentrant=False)
        return stage(*inputs)

    def forward(self, x):
        # Encoder: keep every intermediate map for the skip connections.
        x1 = self._run(self.inc, x)
        x2 = self._run(self.down1, x1)
        x3 = self._run(self.down2, x2)
        x4 = self._run(self.down3, x3)
        x5 = self._run(self.down4, x4)
        # Decoder: each stage fuses the matching encoder output.
        x = self._run(self.up1, x5, x4)
        x = self._run(self.up2, x, x3)
        x = self._run(self.up3, x, x2)
        x = self._run(self.up4, x, x1)
        logits = self._run(self.outc, x)
        return logits

    def use_checkpointing(self):
        """Trade compute for memory by checkpointing every stage in forward().

        The previous implementation called the ``torch.utils.checkpoint``
        module as if it were a function, which raises TypeError. Setting a
        flag instead keeps all submodules registered under their original
        names, so parameters and state_dict keys are unaffected.
        """
        self._use_checkpointing = True

In [2]:
from torch import optim, nn
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader, random_split, Dataset
from tqdm import tqdm

import pandas as pd
import numpy as np
import os
import glob
import pydicom
import matplotlib.pyplot as plt
from time import time_ns


In [3]:
class CustomImageDataset(Dataset):
    """Map-style dataset pairing ``images[i]`` with ``labels[i]``.

    Note the argument order: labels first, images second.

    Args:
        labels: indexable collection of targets; its length is the dataset length.
        images: indexable collection of inputs, aligned with ``labels``.
        transform: optional callable applied to each image.
        target_transform: optional callable applied to each label.
    """

    def __init__(self, labels, images, transform=None, target_transform=None):
        self.labels, self.images = labels, images
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at ``idx`` after optional transforms."""
        sample, target = self.images[idx], self.labels[idx]
        sample = self.transform(sample) if self.transform else sample
        target = self.target_transform(target) if self.target_transform else target
        return sample, target


In [4]:
# Root of the competition data; load_dataset() reads train.csv,
# train_series_descriptions.csv and the train_images/ folder from here.
data_location = "/kaggle/input/rsna-2024-lumbar-spine-degenerative-classification"

def load_dataset():
    """Load the label table and index the image folders found on disk.

    Returns:
        A tuple ``(train_data, image_metadata_set)``:

        * ``train_data`` is ``train.csv`` as a DataFrame.
        * ``image_metadata_set`` maps each study folder name (a string) to a
          dict with ``'folder_path'``, ``'SeriesInstanceUIDs'`` (the series
          sub-folder names) and ``'SeriesDescriptions'``. The two lists are
          index-aligned: a series whose description cannot be resolved is
          reported on stdout and gets a ``None`` placeholder, so later
          entries do not shift position.
    """
    train_data = pd.read_csv(os.path.join(data_location, 'train.csv'))
    images_root = os.path.join(data_location, 'train_images')

    def _visible_entries(folder):
        # Skip macOS '.DS_Store' artefacts; sort because os.listdir order is
        # arbitrary and would make runs non-reproducible.
        return sorted(name for name in os.listdir(folder) if '.DS' not in name)

    image_metadata_set = {}
    for study_id in _visible_entries(images_root):
        folder_path = os.path.join(images_root, study_id)
        image_metadata_set[study_id] = {
            'folder_path': folder_path,
            'SeriesInstanceUIDs': _visible_entries(folder_path),
            # Always present, even for a study without any series folder.
            'SeriesDescriptions': [],
        }

    df_meta_f = pd.read_csv(os.path.join(data_location, 'train_series_descriptions.csv'))

    # One pass over the CSV instead of a full DataFrame scan per series.
    # setdefault keeps the first row for a (study, series) pair, matching the
    # previous `.iloc[0]` behaviour.
    description_by_series = {}
    for study, series, description in zip(df_meta_f['study_id'],
                                          df_meta_f['series_id'],
                                          df_meta_f['series_description']):
        description_by_series.setdefault((study, series), description)

    for study_id, meta in tqdm(image_metadata_set.items()):
        for series_id in meta['SeriesInstanceUIDs']:
            try:
                description = description_by_series[(int(study_id), int(series_id))]
            except (KeyError, ValueError):
                # KeyError: pair missing from the CSV; ValueError: folder name
                # is not an integer. Anything else should surface, not be hidden.
                print("Failed on", series_id, study_id)
                description = None
            meta['SeriesDescriptions'].append(description)

    return train_data, image_metadata_set

In [5]:
def get_imageset_for_patient(sample_patient, train_data, image_metadata_set):
    """Read every DICOM slice of one study, grouped by series description.

    Args:
        sample_patient: mapping/row with a ``'study_id'`` entry; its string
            form is the key into ``image_metadata_set``.
        train_data: unused; kept so existing callers keep working.
        image_metadata_set: per-study metadata with ``'folder_path'``,
            ``'SeriesInstanceUIDs'`` and the index-aligned
            ``'SeriesDescriptions'``.

    Returns:
        Dict mapping each series description to a list of pixel arrays.
        Slices within a series are ordered by the integer in their file name.
        Series that share a description are concatenated in series order
        rather than the later one replacing the earlier.
    """
    # !TODO: Index by some patient id
    study = image_metadata_set[str(sample_patient['study_id'])]

    def _instance_number(path):
        # Slice files are named '<number>.dcm'; sort numerically so that
        # '10.dcm' comes after '2.dcm'.
        return int(os.path.splitext(os.path.basename(path))[0])

    images_by_description = {}
    for idx, series_uid in enumerate(study['SeriesInstanceUIDs']):
        description = study['SeriesDescriptions'][idx]
        dicom_paths = glob.glob(os.path.join(study['folder_path'], series_uid, '*.dcm'))
        pixel_arrays = [pydicom.dcmread(path).pixel_array
                        for path in sorted(dicom_paths, key=_instance_number)]
        images_by_description.setdefault(description, []).extend(pixel_arrays)

    return images_by_description

In [6]:
# Severity grades in increasing order; a grade's position is its numeric code.
LABELS = ['Normal/Mild', 'Moderate', 'Severe']
# Columns of the label table that are identifiers rather than grades.
EXCLUDE_COLS = ["study_id"]


def convert_to_output_vector(train_data: pd.DataFrame):
    """Encode every grade column of ``train_data`` as its index in ``LABELS``.

    Args:
        train_data: label table; every column not listed in ``EXCLUDE_COLS``
            must hold values from ``LABELS``.

    Returns:
        One list of ints per row, in column order.

    Raises:
        ValueError: if a cell holds a value that is not in ``LABELS``.
    """
    # !TODO: Naive approach
    # !TODO: Submission requires bin probabilities, rather than 1d score, need to think about that
    grade_columns = [column for column in train_data.columns if column not in EXCLUDE_COLS]
    encoded_rows = []
    for row in train_data[grade_columns].values:
        encoded_rows.append([LABELS.index(grade) for grade in row])
    return encoded_rows

In [8]:
def train(epoch_count: int = 50, num_studies: int = 10) -> UNet:
    """Train a UNet to regress the per-condition severity codes from single slices.

    Every "Sagittal T1" slice of the first ``num_studies`` training studies is
    paired with that study's label vector and fitted with an L1 loss.

    Args:
        epoch_count: number of passes over the collected slices.
        num_studies: how many studies from the training split to read; capped
            at the size of the split.

    Returns:
        The trained model, moved back to the CPU.

    Raises:
        ValueError: if no slice of the expected size was found.
    """
    SERIES_DESCRIPTION = "Sagittal T1"
    # !TODO: Admit any image size
    # Slices of any other size are skipped; UNet's default head expects 512x512.
    IMAGE_SHAPE = (512, 512)

    study_data, image_metadata_set = load_dataset()
    study_data = study_data.dropna()
    # First approach: just get each individual image, tack on expected labels 0-2, train away
    # .assign builds a new frame instead of setting a column on the dropna()
    # result, which can trigger pandas' chained-assignment warning.
    study_data = study_data.assign(features=convert_to_output_vector(study_data))
    study_data = study_data[["study_id", "features"]]

    # !TODO: Split within same study?
    # !TODO: Tripartite split vs bipartite?
    train_data, val_data = train_test_split(study_data, test_size=0.2)

    num_classes = len(study_data["features"].iloc[0])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(n_channels=1, n_classes=num_classes).to(device)

    images = []
    targets = []
    # min(): the split may hold fewer studies than requested.
    for i in range(min(num_studies, len(train_data))):
        study_row = train_data.iloc[i]
        image_examples = get_imageset_for_patient(study_row, train_data, image_metadata_set)
        # Float targets so the loss sees matching dtypes; storing tensors lets
        # the default collate function produce a (1, num_classes) batch.
        target = torch.tensor(study_row["features"], dtype=torch.float32)
        # .get(): a study without this series contributes nothing instead of
        # raising KeyError.
        for pixels in image_examples.get(SERIES_DESCRIPTION, []):
            image = torch.from_numpy(np.array(pixels, dtype=np.int32)).float()
            # Filter once here rather than on every batch of every epoch.
            if tuple(image.shape) != IMAGE_SHAPE:
                continue
            images.append(image.unsqueeze(0))  # add the channel dimension
            targets.append(target)

    if not images:
        raise ValueError(
            f"No {SERIES_DESCRIPTION!r} slices of shape {IMAGE_SHAPE} found in the "
            f"first {num_studies} training studies; nothing to train on."
        )

    # !TODO: Loader args?
    train_dataset = CustomImageDataset(targets, images)
    train_loader = DataLoader(train_dataset, shuffle=True)

    # Just the first one that comes to mind. To be tuned
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    # loss_fn = nn.CrossEntropyLoss()
    loss_fn = nn.L1Loss()

    model.train()
    for epoch in range(epoch_count):
        total_loss = 0
        start = time_ns()
        for image, target in train_loader:
            image = image.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            output = model(image)  # (1, 1, 1, num_classes)
            loss = loss_fn(output.reshape(target.shape), target)
            total_loss += loss.item()
            loss.backward()
            optimizer.step()
        end = time_ns()
        print("Loss at epoch", epoch, total_loss)
        print("Seconds elapsed at epoch", epoch, (end - start) // 1e9)

    # Hand the model back on the CPU so callers that feed it CPU tensors keep working.
    return model.cpu()

In [None]:
# Run training and keep a handle to the model that train() returns.
model = train()