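# References
- [Transfer learning: ResNet18 classifying MNIST (officeguide.cc tutorial)](https://officeguide.cc/pytorch-transfer-learning-resnet18-classify-mnist-tutorial-examples/)
- [PyTorch tutorial: Transfer Learning for Computer Vision](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html)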

In [1]:
import pandas as pd
import numpy as np
from pathlib import Path
import pydicom
import copy
import time

import torch
from torch import nn, optim
from torchvision import transforms, io
from torchvision.transforms import functional as F

from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

In [13]:
import random

# Seed every RNG this notebook draws from (NumPy, Python's `random`, PyTorch)
# with the same fixed value so runs are reproducible.
for _seeder in (np.random.seed, random.seed, torch.manual_seed):
    _seeder(2022)

<torch._C.Generator at 0x7f5c8599ae10>

In [2]:
def display(tr: torch.Tensor):
    """Summarise a tensor for quick inspection.

    Returns a dict with, in order:
      - 'min' / 'max': 0-d tensors from ``torch.amin`` / ``torch.amax``
        reduced over all elements,
      - 'dtype': the tensor's dtype,
      - 'size': the tensor's ``torch.Size``.

    NOTE(review): in a notebook this name shadows IPython's ``display``
    helper; consider renaming if rich display is needed later.
    """
    return dict(
        min=torch.amin(tr),
        max=torch.amax(tr),
        dtype=tr.dtype,
        size=tr.size(),
    )

In [3]:
# Image pipeline applied to each 3-channel tensor:
#   1. keep only the central 50x50 patch,
#   2. resize so the shorter side is 224 px (square input -> 224x224).
_preprocess_steps = [
    transforms.CenterCrop(50),
    transforms.Resize(224),
]
preprocess = transforms.Compose(_preprocess_steps)

In [4]:
# Full label table (one row per DICOM sample) used for the split checks.
df = pd.read_csv(Path("./data") / "DICOM" / "train.csv")

In [5]:
# 80% training sample (same random_state as DicomDataset), renumbered 0..n-1.
df1 = df.sample(frac=0.8, random_state=2022).reset_index(drop=True)

In [6]:
# Rows of `df` that were not drawn into `df1`: stacking both frames and
# dropping every row that appears more than once leaves only the remainder.
# Caveat: a row duplicated within `df` itself would be dropped here too.
df2 = (
    pd.concat([df, df1])
    .drop_duplicates(keep=False)
    .reset_index(drop=True)
)

In [7]:
tuple(len(frame) for frame in (df1, df2, df))  # (train, validation, total)

(129, 32, 161)

In [8]:
# Sanity check: rows present in both splits (inner join on all shared
# columns). This should print an empty DataFrame.
intersected_df = df1.merge(df2, how="inner")
print(intersected_df)

Empty DataFrame
Columns: [ID, Age, Gender, FilePath, index, Stage]
Index: []


In [14]:
class DicomDataset(torch.utils.data.Dataset):
    """Classification dataset built from ``<root>/DICOM/train.csv``.

    The CSV is split deterministically: the training split is
    ``df.sample(frac=train_frac, random_state=seed)`` and the validation
    split is every remaining row, in CSV order. Each item is an
    ``(image, label)`` pair. ``image`` is a 3-channel float32 tensor built
    from one frame of the row's DICOM file, optionally passed through
    ``transform``. ``label`` is ``Stage - 1``.

    Args:
        root: Data root. The CSV is read from ``root/DICOM/train.csv``, and
            each ``FilePath`` entry (minus its first character) is resolved
            relative to ``root``.
        train: ``True`` selects the training split, ``False`` the validation
            split.
        transform: Callable applied to the image tensor, or ``None`` to
            return it untransformed.
        train_frac: Fraction of rows assigned to the training split.
        seed: ``random_state`` for ``DataFrame.sample``. Train and
            validation datasets must use the same value to stay disjoint.
    """

    def __init__(self, root, train, transform, train_frac=0.8, seed=2022):
        self.root = Path(root)
        self.transform = transform
        df = pd.read_csv(str(self.root / "DICOM/train.csv"))

        # Split by row label rather than by row value, so validation is
        # exactly the complement of training (value-based de-duplication
        # would also discard rows that are duplicated within the CSV).
        train_index = df.sample(frac=train_frac, random_state=seed).index
        if train:
            selected = df.loc[train_index]
        else:
            selected = df.drop(index=train_index)
        self.list = selected.reset_index(drop=True)

        # The first character of each stored FilePath is stripped before
        # joining under root (presumably a leading "/" or "." -- TODO confirm
        # against the CSV).
        self.list["FilePath"] = self.list["FilePath"].apply(
            lambda rel: self.root / rel[1:]
        )

    def __len__(self):
        """Return the number of rows in the selected split."""
        return len(self.list)

    def __getitem__(self, idx):
        """Load row ``idx``'s DICOM frame and return ``(image, label)``."""
        # dcmread is pydicom's supported reader; read_file was a deprecated
        # alias that pydicom 3.0 removed.
        dcm = pydicom.dcmread(str(self.list.loc[idx, "FilePath"]))

        # Stage 1, 2, 3 -> class label 0, 1, 2
        label = int(self.list.loc[idx, "Stage"]) - 1

        # The CSV has a column literally named "index" giving the frame to
        # take from the pixel array (presumably multi-frame -- confirm). It
        # must be read via .loc because DataFrame.index is the row index.
        frame = self.list.loc[idx, "index"]
        pixel = dcm.pixel_array[frame]

        # TODO: no windowing / intensity normalisation is applied yet; raw
        # pixel values are converted to float32 as-is.
        img = torch.tensor(pixel.astype(np.float32))
        # Replicate the single channel to get a (3, H, W) tensor.
        img = torch.stack([img, img, img], dim=0)

        if self.transform is not None:
            img = self.transform(img)

        return img, label


In [15]:
# Build both splits from the same root and preprocessing pipeline:
# first the training split, then the validation split.
training_data, validation_data = (
    DicomDataset(root="./data", train=is_train, transform=preprocess)
    for is_train in (True, False)
)