In [1]:
import os
import datetime
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
import numpy as np
import nibabel as nib
import random
from pathlib import Path
from timm.models.layers import trunc_normal_

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MRIDataset(Dataset):
    def __init__(self, file_path, path_prefix=""):
        self.path_prefix = path_prefix
        full_file_path = file_path
        with open(full_file_path, 'r') as file:
            data = [line.strip().split() for line in file.readlines()]
        
        self.cond_label = {cond_label: idx for idx, cond_label in enumerate(set(row[1] for row in data))}
        self.files = [(row[0], self.cond_label[row[1]]) for row in data]
        random.shuffle(self.files)
    
    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path, cond_label = self.files[idx]
        full_img_path = self.path_prefix + img_path
        
        img = nib.load(full_img_path).get_fdata()
        img = np.float32(img)
        img = torch.from_numpy(img)
        if img.ndim == 4 and img.shape[-1] == 1:
            img = img.squeeze(-1)
        img = img.unsqueeze(0)
        cond_label = torch.tensor(cond_label, dtype=torch.long)
        return img, cond_label

In [3]:
class Block(nn.Module):
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        # Depthwise 3D convolution
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)
        # Layer normalization for 3D (adjusting for channel dimension)
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise convolutions using linear layers
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Layer scaling if it is utilized
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 4, 1)  # Permute to bring channel to last
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 4, 1, 2, 3)  # Permute back to normal

        x = input + self.drop_path(x)
        return x

In [4]:
class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
            return x

In [5]:
class ConvNeXt(nn.Module):
    def __init__(self, in_chans=1, nc_cond=8, 
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1.):
        super().__init__()
        # Initial downsampling
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=5, stride=3),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv3d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 
                layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final normalization
        
        self.fc_layers = nn.Sequential(
            nn.Linear(dims[-1], 4096),
            nn.GELU(),
            nn.Linear(4096, 4096),
            nn.GELU()
        )
        
        self.head_cond = nn.Linear(4096, nc_cond)

        self.apply(self._init_weights)
        self.head_cond.weight.data.mul_(head_init_scale)
        self.head_cond.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        return self.norm(x.mean([-3, -2, -1]))  # global average pooling over spatial dimensions

    def forward(self, x):
        x = self.forward_features(x)
        x = self.fc_layers(x)
        cond_output = self.head_cond(x)
        return cond_output

In [6]:
def train(model, device, train_loader, criterion, optimizer, epoch):
    print(f'Epoch {epoch + 1}, start')
    model.train()
    running_loss = 0.0
    running_log_loss = 0.0
    batch_idx = 0
    for img, cond in train_loader:
        batch_idx += 1
        img = img.to(device)
        cond = cond.to(device)
        optimizer.zero_grad()
        cond_o = model(img)
        loss_cond = criterion(cond_o, cond)
        loss = loss_cond
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        running_log_loss += loss.item()
        if (batch_idx + 1) % 1 == 0:
            current_utc = datetime.datetime.utcnow()
            gmt8_time = current_utc + datetime.timedelta(hours=8)
            current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
            with open('training_log.txt', 'a') as log_file:
                log_entry = (f'Epoch {epoch+1:02}, Batch {batch_idx+1:04}: '
                             f'Train Loss: {running_log_loss / 10:.4f}, '
                             f'Timestamp: {current_time}\n')
                log_file.write(log_entry)
            running_log_loss = 0
    
    return running_loss / len(train_loader)

In [7]:
def validate(model, device, val_loader, criterion, epoch):
    model.eval()
    running_loss = 0.0
    cor_cond = 0
    total = 0
    
    with torch.no_grad():
        for img, cond in val_loader:
            img = img.to(device)
            cond = cond.to(device)
            cond_o = model(img)
            loss_cond = criterion(cond_o, cond)
            loss = loss_cond
            running_loss += loss.item()
            
            _, pred_cond = cond_o.max(1)
            cor_cond += pred_cond.eq(cond).sum().item()
            total += cond.size(0)
            
    val_loss = running_loss / len(val_loader)
    acc_cond = 100. * cor_cond / total
    
    with open('validation_log.txt', 'a') as log_file:
        current_utc = datetime.datetime.utcnow()
        gmt8_time = current_utc + datetime.timedelta(hours=8)
        current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
        log_entry = (f'Epoch {epoch+1:03}, Val Loss: {val_loss:.4f}, '
                     f'Val ACC cond: {acc_cond:.2f}%, Timestamp: {current_time}\n')
        log_file.write(log_entry)

    return val_loss, acc_cond

In [8]:
def BD_validate(model, device, val_loader, criterion, epoch):
    model.eval()
    running_loss = 0.0
    cor_cond = 0
    total = 0
    
    with torch.no_grad():
        for img, cond in val_loader:
            img = img.to(device)
            cond = cond.to(device)
            cond_o = model(img)
            loss_cond = criterion(cond_o, cond)
            loss = loss_cond
            running_loss += loss.item()
            
            _, pred_cond = cond_o.max(1)
            cor_cond += pred_cond.eq(cond).sum().item()
            total += cond.size(0)
            
    val_loss = running_loss / len(val_loader)
    acc_cond = 100. * cor_cond / total
    
    with open('BD_validation_log.txt', 'a') as log_file:
        current_utc = datetime.datetime.utcnow()
        gmt8_time = current_utc + datetime.timedelta(hours=8)
        current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
        log_entry = (f'Epoch {epoch+1:03}, BD Loss: {val_loss:.4f}, '
                     f'BD ACC cond: {acc_cond:.2f}%, Timestamp: {current_time}\n')
        log_file.write(log_entry)

    return val_loss, acc_cond

In [9]:
def main(model_name, path_prefix="", epochs=10, lr=0.001, batch_size = 4):
    # Configuration and Hyperparameters
    batch_size = batch_size
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    train_set_file = f'train.txt'
    val_set_file = f'test.txt'
    BD_set_file = f'BD_test.txt'
    
    glm_trick_path = path_prefix + 'glm_trick/'
    
    train_dataset = MRIDataset(train_set_file, path_prefix=glm_trick_path)
    val_dataset = MRIDataset(val_set_file, path_prefix=glm_trick_path)
    BD_dataset = MRIDataset(BD_set_file, path_prefix=path_prefix)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    BD_loader = DataLoader(BD_dataset, batch_size=batch_size, shuffle=False)
    
    nc_cond = len(train_dataset.cond_label)
    grand_results = []
    
    with open('training_log.txt', 'w') as f:
        f.write("")  # This clears the training log
    with open('validation_log.txt', 'w') as f:
        f.write("")  # This clears the validation log
    with open('BD_validation_log.txt', 'w') as f:
        f.write("")  # This clears the validation log
    
    model = ConvNeXt(in_chans=1, nc_cond=nc_cond)
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr)
    
    current_utc = datetime.datetime.utcnow()
    gmt8_time = current_utc + datetime.timedelta(hours=8)
    current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
    start_time = f'Start training at: {current_time}'
    print(start_time)
        
    # Training and Validation Loop
    for epoch in range(epochs):
        train_loss = train(model, device, train_loader, criterion, optimizer, epoch)
        
        val_loss, val_acc_cond = \
        validate(model, device, val_loader, criterion, epoch)
        
        BD_loss, BD_acc_cond = \
        BD_validate(model, device, BD_loader, criterion, epoch)
        
        current_utc = datetime.datetime.utcnow()
        gmt8_time = current_utc + datetime.timedelta(hours=8)
        current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
        start_time = f'Start training at: {current_time}'
        print(f'Epoch {epoch+1:03}, Train Loss: {train_loss:.4f}, Timestamp: {current_time},\n'
              f'Val Loss: {val_loss:.4f}, Val ACC cond: {val_acc_cond:.2f},\n'
              f'BD Loss: {BD_loss:.4f}, BD ACC cond: {BD_acc_cond:.2f}%')
        
        if (epoch + 1) % 10 == 0 :
            torch.save(model.state_dict(), f'{model_name}_epoch{epoch+1:03}.pth')
        

In [10]:
##################################################
N_ep = 50
N_batch = 8
lr = 0.001

main_path = os.getcwd()
errts_path = '../../preprocess/errts'
classifier_type = 'glm_trick'
os.chdir(classifier_type)

folder_list = [folder for folder in os.listdir() if folder.startswith('s') and os.path.isdir(folder)]

for folder in folder_list:
    os.chdir(folder)
    path_to_main = '../../'
    
    full_errts_path = path_to_main + errts_path + '/' + folder + '/'
    main(classifier_type, path_prefix=full_errts_path, epochs=N_ep, batch_size = N_batch, lr = lr)
    
    # Return to the parent directory
    os.chdir('..')

main_path = os.getcwd()

Start training at: 2024-06-14 11:52:03
Epoch 1, start
Epoch 001, Train Loss: 4.7068, Timestamp: 2024-06-14 11:52:16,
Val Loss: 2.0133, Val ACC cond: 33.33,
BD Loss: 0.8191, BD ACC cond: 68.42%
Epoch 2, start
Epoch 002, Train Loss: 0.7207, Timestamp: 2024-06-14 11:52:28,
Val Loss: 0.5998, Val ACC cond: 66.67,
BD Loss: 0.7700, BD ACC cond: 31.58%
Epoch 3, start
Epoch 003, Train Loss: 0.7078, Timestamp: 2024-06-14 11:52:40,
Val Loss: 0.5844, Val ACC cond: 66.67,
BD Loss: 0.7762, BD ACC cond: 31.58%
Epoch 4, start
Epoch 004, Train Loss: 0.7452, Timestamp: 2024-06-14 11:52:53,
Val Loss: 0.5526, Val ACC cond: 66.67,
BD Loss: 0.8392, BD ACC cond: 31.58%
Epoch 5, start
Epoch 005, Train Loss: 0.7114, Timestamp: 2024-06-14 11:53:05,
Val Loss: 0.8520, Val ACC cond: 33.33,
BD Loss: 0.6174, BD ACC cond: 68.42%
Epoch 6, start
Epoch 006, Train Loss: 0.5905, Timestamp: 2024-06-14 11:53:18,
Val Loss: 0.4896, Val ACC cond: 66.67,
BD Loss: 0.8139, BD ACC cond: 36.84%
Epoch 7, start
Epoch 007, Train Loss:

Epoch 004, Train Loss: 0.6981, Timestamp: 2024-06-14 12:03:47,
Val Loss: 0.7228, Val ACC cond: 42.86,
BD Loss: 0.6687, BD ACC cond: 62.22%
Epoch 5, start
Epoch 005, Train Loss: 0.6918, Timestamp: 2024-06-14 12:03:59,
Val Loss: 0.7081, Val ACC cond: 42.86,
BD Loss: 0.6797, BD ACC cond: 62.22%
Epoch 6, start
Epoch 006, Train Loss: 0.6892, Timestamp: 2024-06-14 12:04:12,
Val Loss: 0.7280, Val ACC cond: 42.86,
BD Loss: 0.6673, BD ACC cond: 62.22%
Epoch 7, start
Epoch 007, Train Loss: 0.6927, Timestamp: 2024-06-14 12:04:26,
Val Loss: 0.7265, Val ACC cond: 42.86,
BD Loss: 0.6670, BD ACC cond: 62.22%
Epoch 8, start
Epoch 008, Train Loss: 0.7113, Timestamp: 2024-06-14 12:04:39,
Val Loss: 0.7012, Val ACC cond: 42.86,
BD Loss: 0.6857, BD ACC cond: 62.22%
Epoch 9, start
Epoch 009, Train Loss: 0.7058, Timestamp: 2024-06-14 12:04:52,
Val Loss: 0.7086, Val ACC cond: 42.86,
BD Loss: 0.6761, BD ACC cond: 62.22%
Epoch 10, start
Epoch 010, Train Loss: 0.7081, Timestamp: 2024-06-14 12:05:05,
Val Loss: 0.

Epoch 007, Train Loss: 0.6949, Timestamp: 2024-06-14 12:15:11,
Val Loss: 0.6788, Val ACC cond: 50.00,
BD Loss: 0.6889, BD ACC cond: 53.66%
Epoch 8, start
Epoch 008, Train Loss: 0.6887, Timestamp: 2024-06-14 12:15:22,
Val Loss: 0.6442, Val ACC cond: 80.00,
BD Loss: 0.6920, BD ACC cond: 43.90%
Epoch 9, start
Epoch 009, Train Loss: 0.6983, Timestamp: 2024-06-14 12:15:34,
Val Loss: 0.6689, Val ACC cond: 50.00,
BD Loss: 0.6883, BD ACC cond: 53.66%
Epoch 10, start
Epoch 010, Train Loss: 0.6934, Timestamp: 2024-06-14 12:15:45,
Val Loss: 0.6470, Val ACC cond: 70.00,
BD Loss: 0.6880, BD ACC cond: 58.54%
Epoch 11, start
Epoch 011, Train Loss: 0.6832, Timestamp: 2024-06-14 12:15:57,
Val Loss: 0.6275, Val ACC cond: 80.00,
BD Loss: 0.6867, BD ACC cond: 58.54%
Epoch 12, start
Epoch 012, Train Loss: 0.6988, Timestamp: 2024-06-14 12:16:08,
Val Loss: 0.6026, Val ACC cond: 80.00,
BD Loss: 0.6842, BD ACC cond: 58.54%
Epoch 13, start
Epoch 013, Train Loss: 0.7242, Timestamp: 2024-06-14 12:16:19,
Val Loss:

Epoch 010, Train Loss: 0.2970, Timestamp: 2024-06-14 12:25:31,
Val Loss: 0.4740, Val ACC cond: 87.50,
BD Loss: 0.8974, BD ACC cond: 65.91%
Epoch 11, start
Epoch 011, Train Loss: 0.1509, Timestamp: 2024-06-14 12:25:44,
Val Loss: 0.4077, Val ACC cond: 87.50,
BD Loss: 1.3095, BD ACC cond: 75.00%
Epoch 12, start
Epoch 012, Train Loss: 0.1340, Timestamp: 2024-06-14 12:25:56,
Val Loss: 1.8800, Val ACC cond: 75.00,
BD Loss: 2.7573, BD ACC cond: 75.00%
Epoch 13, start
Epoch 013, Train Loss: 0.0679, Timestamp: 2024-06-14 12:26:08,
Val Loss: 0.5622, Val ACC cond: 87.50,
BD Loss: 2.2093, BD ACC cond: 72.73%
Epoch 14, start
Epoch 014, Train Loss: 0.0653, Timestamp: 2024-06-14 12:26:21,
Val Loss: 1.6911, Val ACC cond: 87.50,
BD Loss: 2.5277, BD ACC cond: 75.00%
Epoch 15, start
Epoch 015, Train Loss: 0.1024, Timestamp: 2024-06-14 12:26:33,
Val Loss: 1.5568, Val ACC cond: 62.50,
BD Loss: 2.3668, BD ACC cond: 59.09%
Epoch 16, start
Epoch 016, Train Loss: 0.3278, Timestamp: 2024-06-14 12:26:45,
Val Los

Epoch 013, Train Loss: 0.6829, Timestamp: 2024-06-14 12:36:27,
Val Loss: 0.6796, Val ACC cond: 62.50,
BD Loss: 0.6928, BD ACC cond: 53.85%
Epoch 14, start
Epoch 014, Train Loss: 0.6812, Timestamp: 2024-06-14 12:36:39,
Val Loss: 0.7181, Val ACC cond: 25.00,
BD Loss: 0.6519, BD ACC cond: 69.23%
Epoch 15, start
Epoch 015, Train Loss: 0.6795, Timestamp: 2024-06-14 12:36:51,
Val Loss: 0.7532, Val ACC cond: 37.50,
BD Loss: 0.6230, BD ACC cond: 71.79%
Epoch 16, start
Epoch 016, Train Loss: 0.6736, Timestamp: 2024-06-14 12:37:04,
Val Loss: 0.7432, Val ACC cond: 37.50,
BD Loss: 0.6153, BD ACC cond: 71.79%
Epoch 17, start
Epoch 017, Train Loss: 0.6906, Timestamp: 2024-06-14 12:37:16,
Val Loss: 0.7544, Val ACC cond: 25.00,
BD Loss: 0.6202, BD ACC cond: 71.79%
Epoch 18, start
Epoch 018, Train Loss: 0.6800, Timestamp: 2024-06-14 12:37:28,
Val Loss: 0.7084, Val ACC cond: 25.00,
BD Loss: 0.6395, BD ACC cond: 58.97%
Epoch 19, start
Epoch 019, Train Loss: 0.6845, Timestamp: 2024-06-14 12:37:40,
Val Los

Epoch 016, Train Loss: 0.6936, Timestamp: 2024-06-14 12:47:10,
Val Loss: 0.6986, Val ACC cond: 28.57,
BD Loss: 0.6888, BD ACC cond: 71.05%
Epoch 17, start
Epoch 017, Train Loss: 0.6928, Timestamp: 2024-06-14 12:47:22,
Val Loss: 0.6968, Val ACC cond: 28.57,
BD Loss: 0.6900, BD ACC cond: 71.05%
Epoch 18, start
Epoch 018, Train Loss: 0.6923, Timestamp: 2024-06-14 12:47:34,
Val Loss: 0.6999, Val ACC cond: 28.57,
BD Loss: 0.6870, BD ACC cond: 71.05%
Epoch 19, start
Epoch 019, Train Loss: 0.6917, Timestamp: 2024-06-14 12:47:47,
Val Loss: 0.7037, Val ACC cond: 28.57,
BD Loss: 0.6836, BD ACC cond: 71.05%
Epoch 20, start
Epoch 020, Train Loss: 0.6907, Timestamp: 2024-06-14 12:47:59,
Val Loss: 0.7166, Val ACC cond: 28.57,
BD Loss: 0.6728, BD ACC cond: 71.05%
Epoch 21, start
Epoch 021, Train Loss: 0.6934, Timestamp: 2024-06-14 12:48:12,
Val Loss: 0.7304, Val ACC cond: 28.57,
BD Loss: 0.6626, BD ACC cond: 71.05%
Epoch 22, start
Epoch 022, Train Loss: 0.6937, Timestamp: 2024-06-14 12:48:24,
Val Los

Epoch 019, Train Loss: 0.6878, Timestamp: 2024-06-14 12:57:55,
Val Loss: 0.6922, Val ACC cond: 55.56,
BD Loss: 0.6391, BD ACC cond: 65.85%
Epoch 20, start
Epoch 020, Train Loss: 0.6572, Timestamp: 2024-06-14 12:58:08,
Val Loss: 0.7936, Val ACC cond: 33.33,
BD Loss: 0.7389, BD ACC cond: 46.34%
Epoch 21, start
Epoch 021, Train Loss: 0.6003, Timestamp: 2024-06-14 12:58:21,
Val Loss: 0.8518, Val ACC cond: 44.44,
BD Loss: 0.7100, BD ACC cond: 73.17%
Epoch 22, start
Epoch 022, Train Loss: 0.6441, Timestamp: 2024-06-14 12:58:32,
Val Loss: 1.0667, Val ACC cond: 33.33,
BD Loss: 0.7428, BD ACC cond: 68.29%
Epoch 23, start
Epoch 023, Train Loss: 0.5986, Timestamp: 2024-06-14 12:58:44,
Val Loss: 0.8764, Val ACC cond: 55.56,
BD Loss: 0.5893, BD ACC cond: 73.17%
Epoch 24, start
Epoch 024, Train Loss: 0.4263, Timestamp: 2024-06-14 12:58:56,
Val Loss: 0.7192, Val ACC cond: 66.67,
BD Loss: 0.3657, BD ACC cond: 85.37%
Epoch 25, start
Epoch 025, Train Loss: 0.7311, Timestamp: 2024-06-14 12:59:09,
Val Los

In [11]:
##################################################
N_epoch = 10
N_batch = 16
lr = 0.0001

main_path = os.getcwd()

classifier_type = 'between_truth'

os.chdir(classifier_type)

# Get the current directory
model_path = os.getcwd()

# Find directories that match the pattern
for folder in os.listdir(model_path):
    if os.path.isdir(folder) and folder.startswith('s') and (folder.endswith('A') or folder.endswith('B')):
        folder_path = os.path.join(model_path, folder)

        # Change to the target directory
        os.chdir(folder_path)
        try:
            # Run the main function
            main(classifier_type, epochs=N_epoch, batch_size = N_batch, lr = lr)
        finally:
            # Change back to the original directory
            os.chdir(model_path)

FileNotFoundError: [Errno 2] No such file or directory: 'between_truth'