In [1]:
import os
import datetime
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
import numpy as np
import nibabel as nib
import random
from pathlib import Path
from timm.models.layers import trunc_normal_

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MRIDataset(Dataset):
    def __init__(self, file_path, path_prefix=""):
        self.path_prefix = path_prefix
        full_file_path = path_prefix + file_path
        with open(full_file_path, 'r') as file:
            data = [line.strip().split() for line in file.readlines()]
        
        self.sub_label = {sub_label: idx for idx, sub_label in enumerate(set(row[1] for row in data))}
        self.cond_label = {cond_label: idx for idx, cond_label in enumerate(set(row[2] for row in data))}
        self.stg_label = {stg_label: idx for idx, stg_label in enumerate(set(row[3] for row in data))}
        self.act_label = {act_label: idx for idx, act_label in enumerate(set(row[4] for row in data))}
        self.out_label = {out_label: idx for idx, out_label in enumerate(set(row[5] for row in data))}
        self.files = [(row[0], self.sub_label[row[1]], self.cond_label[row[2]], self.stg_label[row[3]],
                       self.act_label[row[4]], self.out_label[row[5]]) for row in data]
        random.shuffle(self.files)
    
    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path, sub_label, cond_label, stg_label, act_label, out_label = self.files[idx]
        full_img_path = self.path_prefix + img_path
        
        img = nib.load(full_img_path).get_fdata()
        img = np.float32(img)
        img = torch.from_numpy(img)
        if img.ndim == 4 and img.shape[-1] == 1:
            img = img.squeeze(-1)
        img = img.unsqueeze(0)
        sub_label = torch.tensor(sub_label, dtype=torch.long)
        cond_label = torch.tensor(cond_label, dtype=torch.long)
        stg_label = torch.tensor(stg_label, dtype=torch.long)
        act_label = torch.tensor(act_label, dtype=torch.long)
        out_label = torch.tensor(out_label, dtype=torch.long)
        return img, sub_label, cond_label, stg_label, act_label, out_label

In [3]:
class Block(nn.Module):
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        # Depthwise 3D convolution
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)
        # Layer normalization for 3D (adjusting for channel dimension)
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise convolutions using linear layers
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Layer scaling if it is utilized
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 4, 1)  # Permute to bring channel to last
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 4, 1, 2, 3)  # Permute back to normal

        x = input + self.drop_path(x)
        return x

In [4]:
class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
            return x

In [5]:
class ConvNeXt(nn.Module):
    def __init__(self, in_chans=1, nc_sub=100, nc_cond=4, nc_stg=4, nc_act=4, nc_out=4, 
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1.):
        super().__init__()
        # Initial downsampling
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=5, stride=3),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv3d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 
                layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final normalization
        
        self.fc_layers = nn.Sequential(
            nn.Linear(dims[-1], 4096),
            nn.GELU(),
            nn.Linear(4096, 4096),
            nn.GELU()
        )
        
        self.head_sub = nn.Linear(4096, nc_sub)
        self.head_cond = nn.Linear(4096, nc_cond)
        self.head_stg = nn.Linear(4096, nc_stg)
        self.head_act = nn.Linear(4096, nc_act)
        self.head_out = nn.Linear(4096, nc_out)

        self.apply(self._init_weights)
        self.head_sub.weight.data.mul_(head_init_scale)
        self.head_sub.bias.data.mul_(head_init_scale)
        self.head_cond.weight.data.mul_(head_init_scale)
        self.head_cond.bias.data.mul_(head_init_scale)
        self.head_stg.weight.data.mul_(head_init_scale)
        self.head_stg.bias.data.mul_(head_init_scale)
        self.head_act.weight.data.mul_(head_init_scale)
        self.head_act.bias.data.mul_(head_init_scale)
        self.head_out.weight.data.mul_(head_init_scale)
        self.head_out.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-3, -2, -1]))  # global average pooling over spatial dimensions

    def forward(self, x):
        x = self.forward_features(x)
        x = self.fc_layers(x)
        sub_output = self.head_sub(x)
        cond_output = self.head_cond(x)
        stg_output = self.head_stg(x)
        act_output = self.head_act(x)
        out_output = self.head_out(x)
        return sub_output, cond_output, stg_output, act_output, out_output

In [6]:
def train(model, device, train_loader, criterion, optimizer, epoch):
    print(f'Epoch {epoch + 1}, start')
    model.train()
    running_loss = 0.0
    running_log_loss = 0.0
    batch_idx = 0
    for img, sub, cond, stg, act, out in train_loader:
        batch_idx += 1
        img = img.to(device)
        sub = sub.to(device)
        cond = cond.to(device)
        stg = stg.to(device)
        act = act.to(device)
        out = out.to(device)
        optimizer.zero_grad()
        sub_o, cond_o, stg_o, act_o, out_o = model(img)
        loss_sub = criterion(sub_o, sub)
        loss_cond = criterion(cond_o, cond)
        loss_stg = criterion(stg_o, stg)
        loss_act = criterion(act_o, act)
        loss_out = criterion(out_o, out)
        loss = loss_sub + loss_cond + loss_stg + loss_act + loss_out
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_log_loss += loss.item()
        if (batch_idx + 1) % 10 == 0:
            current_utc = datetime.datetime.utcnow()
            gmt8_time = current_utc + datetime.timedelta(hours=8)
            current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
            with open('training_log.txt', 'a') as log_file:
                log_entry = (f'Epoch {epoch+1:02}, Batch {batch_idx+1:04}: '
                             f'Train Loss: {running_log_loss / 10:.4f}, '
                             f'Timestamp: {current_time}\n')
                log_file.write(log_entry)
            running_log_loss = 0
    
    return running_loss / len(train_loader)

In [7]:
def validate(model, device, val_loader, criterion, epoch):
    model.eval()
    running_loss = 0.0
    cor_sub = 0
    cor_cond = 0
    cor_stg = 0
    cor_act = 0
    cor_out = 0
    total = 0
    
    with torch.no_grad():
        for img, sub, cond, stg, act, out in val_loader:
            img = img.to(device)
            sub = sub.to(device)
            cond = cond.to(device)
            stg = stg.to(device)
            act = act.to(device)
            out = out.to(device)
            sub_o, cond_o, stg_o, act_o, out_o = model(img)
            loss_sub = criterion(sub_o, sub)
            loss_cond = criterion(cond_o, cond)
            loss_stg = criterion(stg_o, stg)
            loss_act = criterion(act_o, act)
            loss_out = criterion(out_o, out)
            loss = loss_sub + loss_cond + loss_stg + loss_act + loss_out
            running_loss += loss.item()
            
            _, pred_sub = sub_o.max(1)
            _, pred_cond = cond_o.max(1)
            _, pred_stg = stg_o.max(1)
            _, pred_act = act_o.max(1)
            _, pred_out = out_o.max(1)
            cor_sub += pred_sub.eq(sub).sum().item()
            cor_cond += pred_cond.eq(cond).sum().item()
            cor_stg += pred_stg.eq(stg).sum().item()
            cor_act += pred_act.eq(act).sum().item()
            cor_out += pred_out.eq(out).sum().item()
            total += sub.size(0)
            
    val_loss = running_loss / len(val_loader)
    acc_sub = 100. * cor_sub / total
    acc_cond = 100. * cor_cond / total
    acc_stg = 100. * cor_stg / total
    acc_act = 100. * cor_act / total
    acc_out = 100. * cor_out / total
    
    with open('validation_log.txt', 'a') as log_file:
        current_utc = datetime.datetime.utcnow()
        gmt8_time = current_utc + datetime.timedelta(hours=8)
        current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
        log_entry = (f'Epoch {epoch+1:03}, Val Loss: {val_loss:.4f}, Val ACC sub: {acc_sub:.2f}%, '
                     f'Val ACC cond: {acc_cond:.2f}%, Val ACC stg: {acc_stg:.2f}%, Val ACC act: {acc_act:.2f}%, '
                     f'Val ACC out: {acc_out:.2f}%, Timestamp: {current_time}\n')
        log_file.write(log_entry)

    return val_loss, acc_sub, acc_cond, acc_stg, acc_act, acc_out

In [8]:
def main(model_name, path_prefix="", epochs=10, lr=0.001, batch_size = 4):
    # Configuration and Hyperparameters
    batch_size = batch_size
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    train_set_file = f'{model_name}+train.txt'
    val_set_file = f'{model_name}+val.txt'
    
    train_dataset = MRIDataset(train_set_file, path_prefix=path_prefix)
    val_dataset = MRIDataset(val_set_file, path_prefix=path_prefix)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    nc_sub = len(train_dataset.sub_label)
    nc_cond = len(train_dataset.cond_label)
    nc_stg = len(train_dataset.stg_label)
    nc_act = len(train_dataset.act_label)
    nc_out = len(train_dataset.out_label)
    grand_results = []
    
    with open('training_log.txt', 'w') as f:
        f.write("")  # This clears the training log
    with open('validation_log.txt', 'w') as f:
        f.write("")  # This clears the validation log
    
    model = ConvNeXt(in_chans=1, nc_sub=nc_sub, nc_cond=nc_cond, nc_stg=nc_stg, nc_act=nc_act, nc_out=nc_out)
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr)
    
    current_utc = datetime.datetime.utcnow()
    gmt8_time = current_utc + datetime.timedelta(hours=8)
    current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
    start_time = f'Start training at: {current_time}'
    print(start_time)
        
    # Training and Validation Loop
    for epoch in range(epochs):
        train_loss = train(model, device, train_loader, criterion, optimizer, epoch)
        
        val_loss, val_acc_sub, val_acc_cond, val_acc_stg, val_acc_act, val_acc_out = \
        validate(model, device, val_loader, criterion, epoch)
        
        current_utc = datetime.datetime.utcnow()
        gmt8_time = current_utc + datetime.timedelta(hours=8)
        current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
        start_time = f'Start training at: {current_time}'
        print(f'Epoch {epoch+1:03}, Train Loss: {train_loss:.4f}, Timestamp: {current_time},\n'
              f'Val Loss: {val_loss:.4f}, Val ACC sub: {val_acc_sub:.2f}%, \n'
              f'Val ACC cond: {val_acc_cond:.2f}%, Val ACC stg: {val_acc_stg:.2f}%, \n'
              f'Val ACC act: {val_acc_act:.2f}%, Val ACC out: {val_acc_out:.2f}%')
        
        torch.save(model.state_dict(), f'{model_name}_epoch{epoch+1:02}.pth')

In [9]:
##################################################
N_epoch = 30
N_batch = 16

main_path = os.getcwd()

os.chdir(main_path)
classifier_type = 'full'

try:
    os.mkdir(classifier_type)
except FileExistsError:
    print(f"Folder '{classifier_type}' already exists.")

os.chdir(classifier_type)

####################
main(classifier_type, epochs=N_epoch, batch_size = N_batch)

os.chdir(main_path)


Folder 'full' already exists.


IndexError: list index out of range

In [None]:
##################################################
N_ep = 30

main_path = os.getcwd()
errts_path = '../../preprocess/errts'
os.chdir(errts_path)

folder_list = [folder for folder in os.listdir() if folder.startswith('s') and os.path.isdir(folder)]

os.chdir(main_path)
classifier_type = 'condition'

try:
    os.mkdir(classifier_type)
except FileExistsError:
    print(f"Folder '{classifier_type}' already exists.")

os.chdir(classifier_type)

for folder in folder_list:
    # Change to the directory
    os.chdir(main_path)
    try:
        os.mkdir(folder)
    except FileExistsError:
        print(f"Folder '{folder}' already exists.")
    os.chdir(folder)
    path_to_main = '../../'
    full_errts_path = path_to_main + errts_path + '/' + folder + '/'
    main(classifier_type, path_prefix=full_errts_path, epochs=N_ep)
    
    # Return to the parent directory
    os.chdir('..')