In [1]:
import os
import datetime
import glob
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
import numpy as np
import nibabel as nib
import random
from pathlib import Path
from timm.models.layers import trunc_normal_

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MRIDataset(Dataset):
    def __init__(self, file_path, path_prefix=""):
        self.path_prefix = path_prefix
        full_file_path = path_prefix + file_path
        with open(full_file_path, 'r') as file:
            data = [line.strip().split() for line in file.readlines()]
        
        self.sub_label = {sub_label: idx for idx, sub_label in enumerate(set(row[1] for row in data))}
        self.cond_label = {cond_label: idx for idx, cond_label in enumerate(set(row[2] for row in data))}
        self.act_label = {act_label: idx for idx, act_label in enumerate(set(row[3] for row in data))}
        self.files = [(row[0], self.sub_label[row[1]], self.cond_label[row[2]], self.act_label[row[3]]) for row in data]
        random.shuffle(self.files)
    
    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path, sub_label, cond_label, act_label = self.files[idx]
        full_img_path = self.path_prefix + img_path
        
        img = nib.load(full_img_path).get_fdata()
        img = np.float32(img)
        img = torch.from_numpy(img)
        if img.ndim == 4 and img.shape[-1] == 1:
            img = img.squeeze(-1)
        img = img.unsqueeze(0)
        sub_label = torch.tensor(sub_label, dtype=torch.long)
        cond_label = torch.tensor(cond_label, dtype=torch.long)
        act_label = torch.tensor(act_label, dtype=torch.long)
        return img, sub_label, cond_label, act_label

In [3]:
class Block(nn.Module):
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        # Depthwise 3D convolution
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)
        # Layer normalization for 3D (adjusting for channel dimension)
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise convolutions using linear layers
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Layer scaling if it is utilized
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 4, 1)  # Permute to bring channel to last
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 4, 1, 2, 3)  # Permute back to normal

        x = input + self.drop_path(x)
        return x

In [4]:
class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
            return x

In [5]:
class ConvNeXt(nn.Module):
    def __init__(self, in_chans=1, nc_sub=100, nc_cond=4, nc_act=4, 
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1.):
        super().__init__()
        # Initial downsampling
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=5, stride=3),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv3d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 
                layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final normalization
        
        self.fc_layers = nn.Sequential(
            nn.Linear(dims[-1], 4096),
            nn.GELU(),
            nn.Linear(4096, 4096),
            nn.GELU()
        )
        
        self.head_sub = nn.Linear(4096, nc_sub)
        self.head_cond = nn.Linear(4096, nc_cond)
        self.head_act = nn.Linear(4096, nc_act)

        self.apply(self._init_weights)
        self.head_sub.weight.data.mul_(head_init_scale)
        self.head_sub.bias.data.mul_(head_init_scale)
        self.head_cond.weight.data.mul_(head_init_scale)
        self.head_cond.bias.data.mul_(head_init_scale)
        self.head_act.weight.data.mul_(head_init_scale)
        self.head_act.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-3, -2, -1]))  # global average pooling over spatial dimensions

    def forward(self, x):
        x = self.forward_features(x)
        x = self.fc_layers(x)
        sub_output = self.head_sub(x)
        cond_output = self.head_cond(x)
        act_output = self.head_act(x)
        return sub_output, cond_output, act_output

In [6]:
def train(model, device, train_loader, criterion, optimizer, epoch):
    print(f'Epoch {epoch + 1}, start')
    model.train()
    running_loss = 0.0
    running_log_loss = 0.0
    batch_idx = 0
    for img, sub, cond, act in train_loader:
        batch_idx += 1
        img = img.to(device)
        sub = sub.to(device)
        cond = cond.to(device)
        act = act.to(device)
        optimizer.zero_grad()
        sub_o, cond_o, act_o = model(img)
        loss_sub = criterion(sub_o, sub)
        loss_cond = criterion(cond_o, cond)
        loss_act = criterion(act_o, act)
        loss = loss_sub + loss_cond + loss_act
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_log_loss += loss.item()
        if (batch_idx + 1) % 10 == 0:
            current_utc = datetime.datetime.utcnow()
            gmt8_time = current_utc + datetime.timedelta(hours=8)
            current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
            with open('training_log.txt', 'a') as log_file:
                log_entry = (f'Epoch {epoch+1:02}, Batch {batch_idx+1:04}: '
                             f'Train Loss: {running_log_loss / 10:.4f}, '
                             f'Timestamp: {current_time}\n')
                log_file.write(log_entry)
            running_log_loss = 0
    
    return running_loss / len(train_loader)

In [7]:
def validate(model, device, val_loader, criterion, epoch):
    model.eval()
    running_loss = 0.0
    cor_sub = 0
    cor_cond = 0
    cor_act = 0
    total = 0
    
    with torch.no_grad():
        for img, sub, cond, act in val_loader:
            img = img.to(device)
            sub = sub.to(device)
            cond = cond.to(device)
            act = act.to(device)
            sub_o, cond_o, act_o = model(img)
            loss_sub = criterion(sub_o, sub)
            loss_cond = criterion(cond_o, cond)
            loss_act = criterion(act_o, act)
            loss = loss_sub + loss_cond + loss_act
            running_loss += loss.item()
            
            _, pred_sub = sub_o.max(1)
            _, pred_cond = cond_o.max(1)
            _, pred_act = act_o.max(1)
            cor_sub += pred_sub.eq(sub).sum().item()
            cor_cond += pred_cond.eq(cond).sum().item()
            cor_act += pred_act.eq(act).sum().item()
            total += sub.size(0)
            
    val_loss = running_loss / len(val_loader)
    acc_sub = 100. * cor_sub / total
    acc_cond = 100. * cor_cond / total
    acc_act = 100. * cor_act / total
    
    with open('validation_log.txt', 'a') as log_file:
        current_utc = datetime.datetime.utcnow()
        gmt8_time = current_utc + datetime.timedelta(hours=8)
        current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
        log_entry = (f'Epoch {epoch+1:03}, Val Loss: {val_loss:.4f}, Val ACC sub: {acc_sub:.2f}%, '
                     f'Val ACC cond: {acc_cond:.2f}%, Val ACC act: {acc_act:.2f}%, Timestamp: {current_time}\n')
        log_file.write(log_entry)

    return val_loss, acc_sub, acc_cond, acc_act

In [8]:
def main(model_name, path_prefix="", epochs=10, lr=0.001, batch_size = 4):
    # Configuration and Hyperparameters
    batch_size = batch_size
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    train_set_file = f'{model_name}+train.txt'
    val_set_file = f'{model_name}+test.txt'
    
    train_dataset = MRIDataset(train_set_file, path_prefix=path_prefix)
    val_dataset = MRIDataset(val_set_file, path_prefix=path_prefix)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    nc_sub = len(set(train_dataset.sub_label))
    nc_cond = len(set(train_dataset.cond_label))
    nc_act = len(set(train_dataset.act_label))
    grand_results = []
    
    with open('training_log.txt', 'w') as f:
        f.write("")  # This clears the training log
    with open('validation_log.txt', 'w') as f:
        f.write("")  # This clears the validation log
    
    model = ConvNeXt(in_chans=1, nc_sub=nc_sub, nc_cond=nc_cond, nc_act=nc_act)
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr)
    
    current_utc = datetime.datetime.utcnow()
    gmt8_time = current_utc + datetime.timedelta(hours=8)
    current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
    start_time = f'Start training at: {current_time}'
    print(start_time)
    
    checkpoints = glob.glob(f'{model_name}_epoch*.pth')
    if checkpoints:
        latest_checkpoint = max(checkpoints, key=os.path.getctime)
        print(f'Loading checkpoint: {latest_checkpoint}')
        model.load_state_dict(torch.load(latest_checkpoint))
        epoch_number = int(latest_checkpoint.split('_epoch')[1].split('.pth')[0])
        start_epoch = epoch_number
    else:
        print('No checkpoint found, starting training from scratch.')
        start_epoch = 0
    
    # Training and Validation Loop
    for epoch in range(start_epoch, epochs):
        train_loss = train(model, device, train_loader, criterion, optimizer, epoch)
        
        val_loss, val_acc_sub, val_acc_cond, val_acc_act = \
        validate(model, device, val_loader, criterion, epoch)
        
        current_utc = datetime.datetime.utcnow()
        gmt8_time = current_utc + datetime.timedelta(hours=8)
        current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
        start_time = f'Start training at: {current_time}'
        print(f'Epoch {epoch+1:03}, Train Loss: {train_loss:.4f}, Timestamp: {current_time},\n'
              f'Val Loss: {val_loss:.4f}, Val ACC sub: {val_acc_sub:.2f}%, \n'
              f'Val ACC cond: {val_acc_cond:.2f}%, Val ACC act: {val_acc_act:.2f}%')
        
        torch.save(model.state_dict(), f'{model_name}_epoch{epoch+1:03}.pth')

In [9]:
##################################################
N_epoch = 100
N_batch = 12
lr = 0.001

main_path = os.getcwd()

os.chdir(main_path)
classifier_type = 'full'

try:
    os.mkdir(classifier_type)
except FileExistsError:
    print(f"Folder '{classifier_type}' already exists.")

os.chdir(classifier_type)

####################
main(classifier_type, epochs=N_epoch, batch_size = N_batch, lr=lr)

os.chdir(main_path)


Folder 'full' already exists.
Start training at: 2024-06-12 17:50:34
No checkpoint found, starting training from scratch.
Epoch 1, start
Epoch 001, Train Loss: 5.7193, Timestamp: 2024-06-12 17:56:42,
Val Loss: 5.5841, Val ACC sub: 2.36%, 
Val ACC cond: 48.11%, Val ACC act: 60.38%
Epoch 2, start
Epoch 002, Train Loss: 5.3800, Timestamp: 2024-06-12 18:02:19,
Val Loss: 5.5179, Val ACC sub: 2.36%, 
Val ACC cond: 45.75%, Val ACC act: 60.38%
Epoch 3, start
Epoch 003, Train Loss: 5.2455, Timestamp: 2024-06-12 18:07:55,
Val Loss: 5.4190, Val ACC sub: 1.89%, 
Val ACC cond: 40.57%, Val ACC act: 60.38%
Epoch 4, start
Epoch 004, Train Loss: 5.1019, Timestamp: 2024-06-12 18:13:28,
Val Loss: 5.2238, Val ACC sub: 4.72%, 
Val ACC cond: 50.00%, Val ACC act: 60.38%
Epoch 5, start
Epoch 005, Train Loss: 4.9570, Timestamp: 2024-06-12 18:19:01,
Val Loss: 5.0289, Val ACC sub: 7.55%, 
Val ACC cond: 52.36%, Val ACC act: 60.38%
Epoch 6, start
Epoch 006, Train Loss: 4.7781, Timestamp: 2024-06-12 18:24:33,
Val L

Epoch 52, start
Epoch 052, Train Loss: 0.1898, Timestamp: 2024-06-12 22:42:27,
Val Loss: 8.2440, Val ACC sub: 55.19%, 
Val ACC cond: 72.64%, Val ACC act: 58.49%
Epoch 53, start
Epoch 053, Train Loss: 0.1889, Timestamp: 2024-06-12 22:47:53,
Val Loss: 8.2767, Val ACC sub: 50.00%, 
Val ACC cond: 69.34%, Val ACC act: 59.43%
Epoch 54, start
Epoch 054, Train Loss: 0.1446, Timestamp: 2024-06-12 22:53:31,
Val Loss: 8.0991, Val ACC sub: 54.25%, 
Val ACC cond: 75.00%, Val ACC act: 61.79%
Epoch 55, start
Epoch 055, Train Loss: 0.0840, Timestamp: 2024-06-12 22:59:18,
Val Loss: 8.9339, Val ACC sub: 58.96%, 
Val ACC cond: 75.94%, Val ACC act: 61.32%
Epoch 56, start
Epoch 056, Train Loss: 0.1422, Timestamp: 2024-06-12 23:05:06,
Val Loss: 9.3001, Val ACC sub: 52.83%, 
Val ACC cond: 71.23%, Val ACC act: 58.96%
Epoch 57, start
Epoch 057, Train Loss: 0.2600, Timestamp: 2024-06-12 23:10:40,
Val Loss: 7.3882, Val ACC sub: 50.00%, 
Val ACC cond: 75.00%, Val ACC act: 62.26%
Epoch 58, start
Epoch 058, Train L

In [10]:
##################################################
N_ep = 30

main_path = os.getcwd()
errts_path = '../../preprocess/errts'
os.chdir(errts_path)

folder_list = [folder for folder in os.listdir() if folder.startswith('s') and os.path.isdir(folder)]

os.chdir(main_path)
classifier_type = 'condition'

try:
    os.mkdir(classifier_type)
except FileExistsError:
    print(f"Folder '{classifier_type}' already exists.")

os.chdir(classifier_type)

for folder in folder_list:
    # Change to the directory
    os.chdir(main_path)
    try:
        os.mkdir(folder)
    except FileExistsError:
        print(f"Folder '{folder}' already exists.")
    os.chdir(folder)
    path_to_main = '../../'
    full_errts_path = path_to_main + errts_path + '/' + folder + '/'
    main(classifier_type, path_prefix=full_errts_path, epochs=N_ep)
    
    # Return to the parent directory
    os.chdir('..')

Folder 'condition' already exists.
Folder 's11A' already exists.


FileNotFoundError: [Errno 2] No such file or directory: '../../../../preprocess/errts/s11A/condition+train.txt'