In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
import pandas as pd
import os
import shutil
from torch.utils.data import DataLoader

In [2]:
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
    
print("Using PyTorch version:", torch.__version__, 'Device: ', DEVICE)
    

Using PyTorch version: 1.12.0+cu116 Device:  cuda:0


In [3]:
path2xls = './IXI.xls'
labels_df = pd.read_excel(path2xls)
labels_df.head()

Unnamed: 0,IXI_ID,"SEX_ID (1=m, 2=f)",HEIGHT,WEIGHT,ETHNIC_ID,MARITAL_ID,OCCUPATION_ID,QUALIFICATION_ID,DOB,DATE_AVAILABLE,STUDY_DATE,AGE
0,1,1,170,80,2,3,5,2,1968-02-22,0,NaT,
1,2,2,164,58,1,4,1,5,1970-01-30,1,2005-11-18,35.800137
2,12,1,175,70,1,2,1,5,1966-08-20,1,2005-06-01,38.781656
3,13,1,182,70,1,2,1,5,1958-09-15,1,2005-06-01,46.710472
4,14,2,163,65,1,4,1,5,1971-03-15,1,2005-06-09,34.236824


In [4]:
labels_df['IXI_ID']

0        1
1        2
2       12
3       13
4       14
      ... 
614    652
615    653
616    655
617    660
618    662
Name: IXI_ID, Length: 619, dtype: int64

In [5]:
age_df = labels_df[['IXI_ID', 'AGE']]
age_df

Unnamed: 0,IXI_ID,AGE
0,1,
1,2,35.800137
2,12,38.781656
3,13,46.710472
4,14,34.236824
...,...,...
614,652,42.989733
615,653,46.220397
616,655,
617,660,


In [6]:
path_dir = './IXI'
file_list = os.listdir(path_dir)
file_list

['swc1nIXI016-Guys-0697-IXI3DMPRAG_-s231_-0301-00003-000001-01_RAS_denoised.nii',
 'swc1nIXI138-Guys-0746-IXI3DMPRAG_-s240_-0401-00004-000001-01_RAS_denoised.nii',
 'swc1nIXI077-Guys-0752-IXI3DMPRAG_-s242_-0301-00003-000001-01_RAS_denoised.nii',
 'swc1nIXI468-Guys-0985-MPRAGESEN_-s417_-0301-00003-000001-01_RAS_denoised.nii',
 'swc1nIXI629-Guys-1095-MPRAGESEN_-sIXI62_-0301-00003-000001-01_RAS_denoised.nii',
 'swc1nIXI496-Guys-1045-MPRAGESEN_-s421_-0301-00003-000001-01_RAS_denoised.nii',
 'swc1nIXI639-Guys-1088-MPRAGESEN_-s445_-0301-00003-000001-01_RAS_denoised.nii',
 'swc1nIXI336-Guys-0904-MPRAGESEN_-s306_-0301-00003-000001-01_RAS_denoised.nii',
 'swc1nIXI420-Guys-1028-MPRAGESEN_-s424_-0301-00003-000001-01_RAS_denoised.nii',
 'swc1nIXI307-IOP-0872-SAGFSPGR_-sIXI30_-0003-00001-000001-01_RAS_denoised.nii',
 'swc1nIXI219-Guys-0894-MPRAGESEN_-s304_-0301-00003-000001-01_RAS_denoised.nii',
 'swc1nIXI303-IOP-0968-SAGFSPGR_-sIXI30_-0003-00001-000001-01_RAS_denoised.nii',
 'swc1nIXI178-Guys-0778

In [7]:
int(file_list[0][8:11])

def indexing(x):
    return int(x[8:11])
    
    

In [8]:
f_list=[]
for f in file_list:
    x = indexing(f)
    f_list.append(x)
    
f_list
    

[16,
 138,
 77,
 468,
 629,
 496,
 639,
 336,
 420,
 307,
 219,
 303,
 178,
 423,
 309,
 279,
 344,
 153,
 123,
 536,
 331,
 143,
 474,
 450,
 528,
 268,
 170,
 233,
 87,
 587,
 640,
 368,
 222,
 431,
 193,
 44,
 76,
 158,
 119,
 27,
 495,
 114,
 662,
 109,
 200,
 188,
 300,
 393,
 442,
 576,
 351,
 461,
 573,
 597,
 164,
 574,
 207,
 26,
 192,
 28,
 413,
 594,
 63,
 621,
 294,
 399,
 498,
 641,
 84,
 232,
 196,
 417,
 332,
 24,
 493,
 411,
 50,
 552,
 37,
 209,
 55,
 390,
 315,
 401,
 548,
 17,
 113,
 154,
 595,
 45,
 652,
 342,
 653,
 36,
 265,
 500,
 64,
 182,
 115,
 415,
 549,
 388,
 463,
 23,
 112,
 488,
 308,
 359,
 433,
 20,
 408,
 365,
 418,
 381,
 364,
 469,
 430,
 370,
 288,
 241,
 454,
 35,
 330,
 371,
 43,
 480,
 416,
 644,
 157,
 625,
 118,
 73,
 464,
 426,
 183,
 648,
 456,
 409,
 230,
 348,
 592,
 166,
 324,
 61,
 558,
 172,
 41,
 491,
 523,
 551,
 312,
 305,
 503,
 197,
 434,
 224,
 517,
 74,
 375,
 400,
 563,
 120,
 406,
 326,
 285,
 186,
 100,
 286,
 25,
 199,
 266,
 

In [9]:
len(f_list)

381

In [10]:
IXI_df1 = pd.DataFrame(zip(f_list, file_list), columns = ['IXI_ID', 'file_name'])

In [11]:
IXI_df2 = pd.merge(IXI_df1, age_df, how='outer', on='IXI_ID')


In [12]:
IXI_df3 = IXI_df2.drop_duplicates(['IXI_ID']
                                  )
IXI_df4 = IXI_df3.dropna()

In [13]:
IXI_df4 = IXI_df4.drop('IXI_ID',axis=1)
IXI_df4 

Unnamed: 0,file_name,AGE
0,swc1nIXI016-Guys-0697-IXI3DMPRAG_-s231_-0301-0...,55.167693
1,swc1nIXI138-Guys-0746-IXI3DMPRAG_-s240_-0401-0...,33.138946
2,swc1nIXI077-Guys-0752-IXI3DMPRAG_-s242_-0301-0...,36.479124
3,swc1nIXI468-Guys-0985-MPRAGESEN_-s417_-0301-00...,67.698836
4,swc1nIXI629-Guys-1095-MPRAGESEN_-sIXI62_-0301-...,59.263518
...,...,...
399,swc1nIXI373-IOP-0967-SAGFSPGR_-sIXI37_-0003-00...,58.792608
400,swc1nIXI428-Guys-0996-MPRAGESEN_-s419_-0301-00...,55.540041
401,swc1nIXI531-Guys-1057-MPRAGESEN_-s427_-0301-00...,75.937029
403,swc1nIXI458-Guys-0993-MPRAGESEN_-s418_-0301-00...,70.710472


In [14]:

from torch.utils.data import Dataset

torch.manual_seed(0)

class PBA_Dataset(Dataset):
    def __init__(self, data_dir, dataframe ,  transform=None):
        path2data = data_dir
        
        filenames = dataframe['file_name'].to_list()
        
        self.full_filenames = [os.path.join(path2data, f) for f in filenames]
        
        labels_df = dataframe
        labels_df.set_index('file_name', inplace=True)
        
        self.labels = [labels_df.loc[filename].values[0] for filename in filenames]
        
        self.transform = transform
        
    def __len__(self):
        return len(self.full_filenames)
    def __getitem__(self, idx):
        image_3d = nib.load(self.full_filenames[idx])
        image = image_3d.get_fdata()
        image = (image - image.min())/(image.max()-image.min())
        image = image*255.0
        image = np.expand_dims(image, axis=0)
        
        label = self.labels[idx]
        return image, label
    
import torchvision.transforms as transforms
data_transformer = transforms.Compose([transforms.ToTensor()])

data_dir = './IXI/'
    

In [15]:
from sklearn.model_selection import train_test_split

def get_data(data_file):
    data = data_file
    print("data.shape: {}".format(data.shape))

    trainingSet, test_df = train_test_split(data, test_size=0.2, random_state=45346)
    train_df, val_df = train_test_split(trainingSet, test_size=0.2, random_state=257572)

    print("Train.shape: {}, Val.shape: {}, Test.shape: {}".format(train_df.shape, val_df.shape, test_df.shape))

    data.loc[train_df.index, 'split_type'] = "Train"
    data.loc[val_df.index, 'split_type'] = "Val"
    data.loc[test_df.index, 'split_type'] = "Test"

    train_df = data.loc[train_df.index, :]
    val_df = data.loc[val_df.index, :]
    test_df = data.loc[test_df.index, :]

    return train_df, val_df, test_df


In [16]:
train_df, val_df, test_df = get_data(IXI_df4)

data.shape: (381, 2)
Train.shape: (243, 2), Val.shape: (61, 2), Test.shape: (77, 2)


In [17]:
train_data = PBA_Dataset(data_dir, train_df, data_transformer)
val_data = PBA_Dataset(data_dir, val_df, data_transformer)
test_data = PBA_Dataset(data_dir, test_df, data_transformer)

In [18]:
from torch.utils.data import DataLoader

In [19]:
# Hyper parameters
NUM_EPOCHS = 100
BATCH_SIZE = 16
LEARNING_RATE = 0.001
IMAGE_SHAPE = (121, 145, 121)


In [20]:
trn_loader = DataLoader(train_data, shuffle=True, num_workers=4, batch_size=BATCH_SIZE, drop_last=True)
val_loader = DataLoader(val_data, shuffle=True, num_workers=4, batch_size=BATCH_SIZE, drop_last=True)
tst_loader = DataLoader(test_data, shuffle=True,batch_size=BATCH_SIZE)

In [21]:
print(trn_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7f5d07d7edd0>


In [22]:
for (X_train, y_train) in trn_loader:

    print('X_train:', X_train.size(), 'type:', X_train.type())
    print('y_train:', y_train.size(), 'type:', y_train.type())
    print(y_train)
    break

X_train: torch.Size([1, 1, 121, 145, 121]) type: torch.DoubleTensor
y_train: torch.Size([1]) type: torch.DoubleTensor
tensor([42.2204], dtype=torch.float64)


In [23]:
'''class Model3D(nn.Module):
    def __init__(self, n_classes = 1):
        super(Model3D,self).__init__()
        self.conv1 = nn.Conv3d(1, 32, kernel_size=3,stride=(1,1,1),padding=1)
        self.conv2 = nn.Conv3d(32, 64, kernel_size=3,stride=(1,1,1),padding=1)
        self.conv3 = nn.Conv3d(64, 128, kernel_size=3,stride=(1,1,1),padding=1)
        self.conv4 = nn.Conv3d(128, 256, kernel_size=3,stride=(1,1,1),padding=1)
        self.conv5 = nn.Conv3d(256, 256, kernel_size=3,stride=(1,1,1),padding=1)
        self.conv6 = nn.Conv3d(256, 64, kernel_size=1,stride=(1,1,1))

        self.batchnorm1 = nn.BatchNorm3d(32)
        self.batchnorm2 = nn.BatchNorm3d(64)
        self.batchnorm3 = nn.BatchNorm3d(128)
        self.batchnorm4 = nn.BatchNorm3d(256)
        self.batchnorm5 = nn.BatchNorm3d(256)
        self.batchnorm6 = nn.BatchNorm3d(64)

        self.maxpool = nn.MaxPool3d(kernel_size=(2,2,2),stride=(2,2,2))
        self.avgpool = nn.AvgPool3d(kernel_size=(3,4,3),stride=(1,1,1))
        self.dropout = nn.Dropout3d(p=0.5)
        self.relu = nn.ReLU()
        self.classifier = nn.Conv3d(64, n_classes, kernel_size=1, stride=(1,1,1))
        self.softmax = nn.Softmax(dim=1)

    def forward(self,x):
        #input size = (B, 1, 160 , 192, 160)
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.maxpool(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.batchnorm2(x)
        x = self.maxpool(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.batchnorm3(x)
        x = self.maxpool(x)
        x = self.relu(x)
        x = self.conv4(x)
        x = self.batchnorm4(x)
        x = self.maxpool(x)
        x = self.relu(x)
        x = self.conv5(x)
        x = self.batchnorm5(x)
        x = self.maxpool(x)
        x = self.relu(x)
        # feature map with (B, 256, 5, 6, 5)
        x = self.conv6(x)
        x = self.batchnorm6(x)
        x = self.relu(x)
        x = self.avgpool(x)
        x = self.dropout(x)
        x = self.classifier(x)
        x = x.view(-1, self._to_linear)
        
        return x
        '''

'class Model3D(nn.Module):\n    def __init__(self, n_classes = 1):\n        super(Model3D,self).__init__()\n        self.conv1 = nn.Conv3d(1, 32, kernel_size=3,stride=(1,1,1),padding=1)\n        self.conv2 = nn.Conv3d(32, 64, kernel_size=3,stride=(1,1,1),padding=1)\n        self.conv3 = nn.Conv3d(64, 128, kernel_size=3,stride=(1,1,1),padding=1)\n        self.conv4 = nn.Conv3d(128, 256, kernel_size=3,stride=(1,1,1),padding=1)\n        self.conv5 = nn.Conv3d(256, 256, kernel_size=3,stride=(1,1,1),padding=1)\n        self.conv6 = nn.Conv3d(256, 64, kernel_size=1,stride=(1,1,1))\n\n        self.batchnorm1 = nn.BatchNorm3d(32)\n        self.batchnorm2 = nn.BatchNorm3d(64)\n        self.batchnorm3 = nn.BatchNorm3d(128)\n        self.batchnorm4 = nn.BatchNorm3d(256)\n        self.batchnorm5 = nn.BatchNorm3d(256)\n        self.batchnorm6 = nn.BatchNorm3d(64)\n\n        self.maxpool = nn.MaxPool3d(kernel_size=(2,2,2),stride=(2,2,2))\n        self.avgpool = nn.AvgPool3d(kernel_size=(3,4,3),strid

In [24]:
class Model3D(nn.Module):
    def __init__(self):
        super(Model3D, self).__init__()
        self.block1 = nn.Sequential(
            nn.Conv3d(1, 8, 3, stride=1),
            nn.ReLU(),
            nn.Conv3d(8, 8, 3, stride=1),
            nn.BatchNorm3d(8),
            nn.ReLU(),
            nn.MaxPool3d(2, stride=2))

        self.block2 = nn.Sequential(
            nn.Conv3d(8, 16, 3, stride=1),
            nn.ReLU(),
            nn.Conv3d(16, 16, 3, stride=1),
            nn.BatchNorm3d(16),
            nn.ReLU(),
            nn.MaxPool3d(2, stride=2))

        self.block3 = nn.Sequential(
            nn.Conv3d(16, 32, 3, stride=1),
            nn.ReLU(),
            nn.Conv3d(32, 32, 3, stride=1),
            nn.BatchNorm3d(32),
            nn.ReLU(),
            nn.MaxPool3d(2, stride=2))

        self.block4 = nn.Sequential(
            nn.Conv3d(32, 64, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv3d(64, 64, 3, stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(),
            nn.MaxPool3d(2, stride=2))

        self.block5 = nn.Sequential(
            nn.Conv3d(64, 128, 3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv3d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm3d(128),
            nn.ReLU(),
            nn.MaxPool3d(2, stride=2)
        )

        self.classifier = nn.Linear(1536, 1)


    def forward(self, x):
        x = self.block1(x)
        # print(x.shape)
        x = self.block2(x)
        # print(x.shape)
        x = self.block3(x)
        # print(x.shape)
        x = self.block4(x)
        # print(x.shape)
        x = self.block5(x)
        # print(x.shape)
        x = self.classifier(x.view(-1, x.shape[1]*x.shape[2]*x.shape[3]*x.shape[4]))
        # print(x.shape)
        

        return x

In [25]:
model = Model3D().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

In [26]:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()

In [27]:
def train(model, train_loader, optimizer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        
        data, target = data.float().to(DEVICE), target.float().to(DEVICE)
        optimizer.zero_grad()
        output = model(data)
  
        loss = criterion(output.squeeze(), target)
        
        loss.backward()
        optimizer.step()
        



In [28]:
def eval(model, test_loader):
    model.eval()
    test_loss=0
    
    
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.float().to(DEVICE), target.float().to(DEVICE)
            output = model(data)
           
            
            test_loss += criterion(output.squeeze(), target).item()
        test_loss /= len(test_loader.dataset)
    return test_loss
            
            

In [29]:
import time
import copy

def train_baseline(model, train_loader, val_loader,
                   optimizer, num_epochs = NUM_EPOCHS):
     
   
    
      for epoch in range(1, num_epochs + 1):
            since = time.time()
            train(model, train_loader, optimizer)
            train_loss = eval(model, train_loader)
            writer.add_scalar("loss/train_eval", train_loss, epoch)
            val_loss = eval(model, val_loader)
            writer.add_scalar("loss/val_eval", val_loss, epoch)

            if epoch % 10 == 0:
              for name, param in model.named_parameters():
                writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)
            time_elapsed = time.time() - since
            print('----------epoch {}----------'.format(epoch))
        
            print('train Loss: {:.4f}'
              .format(train_loss))
            print('val Loss: {:.4f}'
              .format(val_loss))
            print('Completed in {:.0f}m {:.0f}s'
              .format(time_elapsed // 60, time_elapsed % 60))
           
      return model



In [30]:
base = train_baseline(model, trn_loader, val_loader, optimizer, NUM_EPOCHS)
writer.flush()
writer.close()

  return F.mse_loss(input, target, reduction=self.reduction)


----------epoch 1----------
train Loss: 41.0703
val Loss: 40.1788
Completed in 0m 19s
----------epoch 2----------
train Loss: 184.1154
val Loss: 186.5449
Completed in 0m 17s
----------epoch 3----------
train Loss: 17.1760
val Loss: 18.4904
Completed in 0m 16s
----------epoch 4----------
train Loss: 45.7149
val Loss: 48.7135
Completed in 0m 17s
----------epoch 5----------
train Loss: 16.4194
val Loss: 17.8272
Completed in 0m 16s
----------epoch 6----------
train Loss: 20.9814
val Loss: 21.0391
Completed in 0m 16s
----------epoch 7----------
train Loss: 33.3017
val Loss: 36.9703
Completed in 0m 16s
----------epoch 8----------
train Loss: 19.0692
val Loss: 21.9971
Completed in 0m 16s
----------epoch 9----------
train Loss: 19.7203
val Loss: 19.1730
Completed in 0m 16s
----------epoch 10----------
train Loss: 14.6595
val Loss: 15.5902
Completed in 0m 16s
----------epoch 11----------
train Loss: 19.7393
val Loss: 19.8627
Completed in 0m 16s
----------epoch 12----------
train Loss: 13.8737
v

In [31]:
#torch.save(model.state_dict(), './Adam_RMSE3.pth')

In [32]:
writer.close()