In [1]:
import datetime
import os
import random
from pathlib import Path

import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from timm.models.layers import DropPath, trunc_normal_
from torch.utils.data import DataLoader, Dataset, Subset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MRIDataset(Dataset):
    """Dataset of NIfTI volumes listed in a whitespace-separated index file.

    Every non-blank line of the index file must hold at least three fields::

        <image path> <condition label> <subject label>

    Condition and subject labels are mapped to integer class indices in
    *sorted* label order, so the encoding is reproducible across runs and is
    identical for any two index files that contain the same set of labels.

    Args:
        file_path: Path of the index file, appended to ``path_prefix``.
        path_prefix: String prepended (plain concatenation) to ``file_path``
            and to every image path read from the index file.

    Raises:
        ValueError: If a non-blank line has fewer than three fields.
    """

    def __init__(self, file_path, path_prefix=""):
        self.path_prefix = path_prefix
        full_file_path = path_prefix + file_path
        with open(full_file_path, 'r') as file:
            # Skip blank lines (e.g. a trailing newline at end of file).
            data = [line.split() for line in file if line.strip()]

        for row in data:
            if len(row) < 3:
                raise ValueError(
                    f"Malformed line in {full_file_path!r}: expected "
                    f"'<path> <cond> <subject>', got {' '.join(row)!r}"
                )

        # Sorting makes the label -> index mapping deterministic; iterating a
        # bare set of strings has no guaranteed order.
        self.cond_label = {label: idx for idx, label in enumerate(sorted({row[1] for row in data}))}
        self.sub_label = {label: idx for idx, label in enumerate(sorted({row[2] for row in data}))}
        self.files = [(row[0], self.cond_label[row[1]], self.sub_label[row[2]]) for row in data]
        random.shuffle(self.files)

    def __len__(self):
        """Return the number of indexed volumes."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one volume and return ``(image, cond_label, sub_label)``.

        The image is a float32 tensor with a leading channel axis of size 1;
        a trailing singleton axis on a 4-D volume is dropped first.  Both
        labels are 0-dim ``torch.long`` tensors.
        """
        img_path, cond_label, sub_label = self.files[idx]
        full_img_path = self.path_prefix + img_path

        img = nib.load(full_img_path).get_fdata()
        img = np.float32(img)
        img = torch.from_numpy(img)
        if img.ndim == 4 and img.shape[-1] == 1:
            img = img.squeeze(-1)
        img = img.unsqueeze(0)  # add channel axis
        cond_label = torch.tensor(cond_label, dtype=torch.long)
        sub_label = torch.tensor(sub_label, dtype=torch.long)
        return img, cond_label, sub_label

In [3]:
class Block(nn.Module):
    """Residual ConvNeXt-style block operating on 5-D ``(N, C, D, H, W)`` tensors.

    Pipeline: depthwise 7x7x7 Conv3d -> LayerNorm (channels last) ->
    Linear(dim, 4*dim) -> GELU -> Linear(4*dim, dim) -> optional per-channel
    layer scale -> optional stochastic depth -> residual add.

    Args:
        dim: Number of input (and output) channels.
        drop_path: Stochastic-depth rate; ``DropPath`` (from timm) is only
            instantiated when this is greater than 0, otherwise an identity
            module is used.
        layer_scale_init_value: Initial value of the learnable per-channel
            scale ``gamma``; a value <= 0 disables layer scaling.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        # Depthwise 3D convolution (groups == channels); padding keeps the spatial size.
        self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)
        # Normalisation over the channel axis, applied in channels-last layout.
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise (1x1x1) convolutions expressed as linear layers on the last axis.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        # Learnable per-channel layer scale, if enabled.
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        shortcut = x  # kept for the residual connection
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 4, 1)  # (N, C, D, H, W) -> (N, D, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 4, 1, 2, 3)  # back to (N, C, D, H, W)

        return shortcut + self.drop_path(x)

In [4]:
class LayerNorm(nn.Module):
    """Layer normalisation over the channel axis for two memory layouts.

    * ``"channels_last"``: channels are the last axis; delegates to
      ``F.layer_norm``.
    * ``"channels_first"``: channels are axis 1 of a 5-D tensor; mean and
      (biased) variance are computed over axis 1 and the affine parameters
      are broadcast over the three trailing axes.

    Args:
        normalized_shape: Number of channels.
        eps: Value added to the variance for numerical stability.
        data_format: ``"channels_last"`` (default) or ``"channels_first"``.

    Raises:
        NotImplementedError: If ``data_format`` is not one of the two
            supported layouts.
    """

    _SUPPORTED_FORMATS = ("channels_last", "channels_first")

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in self._SUPPORTED_FORMATS:
            # Fail at construction time instead of returning None from forward().
            raise NotImplementedError(
                f"Unsupported data_format {data_format!r}; "
                f"expected one of {self._SUPPORTED_FORMATS}"
            )
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalise ``x`` along its channel axis and apply the affine transform."""
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        if self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            # Broadcast the per-channel parameters over the three trailing axes.
            return self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
        # Only reachable if data_format was changed after construction.
        raise NotImplementedError(f"Unsupported data_format {self.data_format!r}")

In [5]:
class ConvNeXt(nn.Module):
    """3-D ConvNeXt-style classifier with a single classification head.

    Architecture: a strided Conv3d stem followed by ``len(depths)`` stages of
    ``Block`` modules, with a LayerNorm + stride-2 Conv3d downsampling layer
    between consecutive stages; global average pooling over the three spatial
    axes; LayerNorm; two 4096-unit GELU fully connected layers; and a linear
    head producing ``nc_cond`` logits.

    Args:
        in_chans: Number of input channels.
        nc_cond: Number of output classes.
        depths: Number of blocks in each stage.
        dims: Channel width of each stage (same length as ``depths``).
        drop_path_rate: Maximum stochastic-depth rate; per-block rates grow
            linearly from 0 to this value across all blocks.
        layer_scale_init_value: Initial layer-scale value passed to each block.
        head_init_scale: Factor applied to the head's initial weights and bias.

    Raises:
        ValueError: If ``depths`` and ``dims`` are empty or differ in length.
    """

    # NOTE: the list defaults below are only read, never mutated.
    def __init__(self, in_chans=1, nc_cond=8,
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1.):
        super().__init__()
        if len(depths) == 0 or len(depths) != len(dims):
            raise ValueError(
                f"depths and dims must be non-empty and of equal length, "
                f"got {len(depths)} and {len(dims)}"
            )
        num_stages = len(depths)

        # Stem plus one downsampling layer in front of every later stage.
        self.downsample_layers = nn.ModuleList()
        stem = nn.Sequential(
            nn.Conv3d(in_chans, dims[0], kernel_size=5, stride=3),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(num_stages - 1):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv3d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        # Feature stages; the drop-path rate increases linearly with block index.
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(num_stages):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value)
                  for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final normalization

        self.fc_layers = nn.Sequential(
            nn.Linear(dims[-1], 4096),
            nn.GELU(),
            nn.Linear(4096, 4096),
            nn.GELU()
        )

        self.head_cond = nn.Linear(4096, nc_cond)

        self.apply(self._init_weights)
        self.head_cond.weight.data.mul_(head_init_scale)
        self.head_cond.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        """Truncated-normal init (std 0.02) for conv/linear weights; zero biases."""
        if isinstance(m, (nn.Conv3d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Run all stages and return the normalised, spatially pooled features."""
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = downsample(x)
            x = stage(x)
        return self.norm(x.mean([-3, -2, -1]))  # global average pooling over spatial dimensions

    def forward(self, x):
        """Return the ``nc_cond`` classification logits for input ``x``."""
        x = self.forward_features(x)
        x = self.fc_layers(x)
        cond_output = self.head_cond(x)
        return cond_output

In [6]:
def train(model, device, train_loader, criterion, optimizer, epoch, log_interval=1):
    """Train ``model`` for one epoch and append progress to ``training_log.txt``.

    Args:
        model: Network to optimise; called as ``model(img)``.
        device: Device the batches are moved to.
        train_loader: Iterable yielding ``(img, cond, _)`` batches; must
            support ``len()``.
        criterion: Loss called as ``criterion(output, cond)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch: Zero-based epoch index (reported one-based in logs).
        log_interval: Write one log line every this many batches; the logged
            loss is the mean over the batches since the previous line.
            Batches after the last full interval are not logged.

    Returns:
        Mean training loss over all batches of the epoch.

    Raises:
        ValueError: If ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    print(f'Epoch {epoch + 1}, start')
    model.train()
    running_loss = 0.0
    running_log_loss = 0.0
    batches_since_log = 0
    # One-based counter so the first batch is reported as batch 1.
    for batch_idx, (img, cond, _) in enumerate(train_loader, start=1):
        img = img.to(device)
        cond = cond.to(device)
        optimizer.zero_grad()
        cond_o = model(img)
        loss = criterion(cond_o, cond)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        running_log_loss += batch_loss
        batches_since_log += 1

        if batches_since_log == log_interval:
            # Timestamps are reported in GMT+8.
            current_utc = datetime.datetime.utcnow()
            gmt8_time = current_utc + datetime.timedelta(hours=8)
            current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
            with open('training_log.txt', 'a') as log_file:
                log_entry = (f'Epoch {epoch+1:02}, Batch {batch_idx:04}: '
                             f'Train Loss: {running_log_loss / batches_since_log:.4f}, '
                             f'Timestamp: {current_time}\n')
                log_file.write(log_entry)
            running_log_loss = 0.0
            batches_since_log = 0

    return running_loss / len(train_loader)

In [7]:
def validate(model, device, val_loader, criterion, epoch):
    """Evaluate ``model`` on ``val_loader`` and append a line to ``validation_log.txt``.

    Args:
        model: Network to evaluate; called as ``model(img)``.
        device: Device the batches are moved to.
        val_loader: Iterable yielding ``(img, cond, _)`` batches; must
            support ``len()``.
        criterion: Loss called as ``criterion(output, cond)``.
        epoch: Zero-based epoch index (reported one-based in the log).

    Returns:
        Tuple ``(val_loss, acc_cond)``: mean per-batch loss and the
        percentage of correctly classified samples.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for volumes, targets, _ in val_loader:
            volumes = volumes.to(device)
            targets = targets.to(device)
            logits = model(volumes)
            loss_sum += criterion(logits, targets).item()

            # Index of the largest logit along the class axis is the prediction.
            predicted = logits.max(1)[1]
            n_correct += predicted.eq(targets).sum().item()
            n_seen += targets.size(0)

    val_loss = loss_sum / len(val_loader)
    acc_cond = 100. * n_correct / n_seen

    # Timestamp shifted from UTC to GMT+8.
    stamp = (datetime.datetime.utcnow() + datetime.timedelta(hours=8)).strftime("%Y-%m-%d %H:%M:%S")
    line = (f'Epoch {epoch+1:03}, Val Loss: {val_loss:.4f}, '
            f'Val ACC cond: {acc_cond:.2f}%, Timestamp: {stamp}\n')
    with open('validation_log.txt', 'a') as log_file:
        log_file.write(line)

    return val_loss, acc_cond

In [8]:
def _gmt8_timestamp():
    """Return the current time shifted from UTC to GMT+8, formatted for logs."""
    gmt8_time = datetime.datetime.utcnow() + datetime.timedelta(hours=8)
    return gmt8_time.strftime("%Y-%m-%d %H:%M:%S")


def _remap_cond_labels(reference, dataset):
    """Re-encode ``dataset``'s condition indices with ``reference``'s label map.

    Each dataset builds its own label -> index mapping, so two datasets may
    encode the same condition differently.  This rewrites ``dataset.files``
    and ``dataset.cond_label`` in place to use ``reference.cond_label``.

    Raises:
        ValueError: If ``dataset`` contains a condition label that
            ``reference`` does not know.
    """
    unseen = sorted(name for name in dataset.cond_label if name not in reference.cond_label)
    if unseen:
        raise ValueError(f"Condition labels missing from the training set: {unseen}")
    index_to_name = {idx: name for name, idx in dataset.cond_label.items()}
    dataset.files = [
        (path, reference.cond_label[index_to_name[cond_idx]], sub_idx)
        for path, cond_idx, sub_idx in dataset.files
    ]
    dataset.cond_label = dict(reference.cond_label)


def main(model_name, path_prefix="", epochs=10, lr=0.001, batch_size = 4):
    """Train and validate a ``ConvNeXt`` classifier, then save its weights.

    Reads ``train.txt`` and ``test.txt`` (each prefixed with ``path_prefix``),
    truncates ``training_log.txt`` and ``validation_log.txt`` in the current
    working directory, trains for ``epochs`` epochs with Adam and
    cross-entropy loss, and writes the final state dict to
    ``{model_name}_epoch{epochs}.pth`` in the current working directory.

    Args:
        model_name: Prefix of the saved checkpoint file name.
        path_prefix: Prefix passed to ``MRIDataset`` for all data paths.
        epochs: Number of training epochs.
        lr: Adam learning rate.
        batch_size: Batch size for both data loaders.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_dataset = MRIDataset('train.txt', path_prefix=path_prefix)
    val_dataset = MRIDataset('test.txt', path_prefix=path_prefix)
    # Make sure both splits use the same class index for the same condition.
    _remap_cond_labels(train_dataset, val_dataset)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    nc_cond = len(train_dataset.cond_label)

    # Start every run with empty log files; train()/validate() append to them.
    for log_name in ('training_log.txt', 'validation_log.txt'):
        with open(log_name, 'w') as f:
            f.write("")

    model = ConvNeXt(in_chans=1, nc_cond=nc_cond)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr)

    print(f'Start training at: {_gmt8_timestamp()}')

    # Training and Validation Loop
    for epoch in range(epochs):
        train_loss = train(model, device, train_loader, criterion, optimizer, epoch)
        val_loss, val_acc_cond = validate(model, device, val_loader, criterion, epoch)

        current_time = _gmt8_timestamp()
        print(f'Epoch {epoch+1:03}, Train Loss: {train_loss:.4f}, Timestamp: {current_time},\n'
              f'Val Loss: {val_loss:.4f}, Val ACC cond: {val_acc_cond:.2f}%')

    # The file name reflects the number of epochs actually trained.
    torch.save(model.state_dict(), f'{model_name}_epoch{epochs}.pth')

In [None]:
##################################################
# Train one classifier per subject folder under ./between_truth.
N_epoch = 10
N_batch = 16
lr = 0.0001

# Directory the notebook was started from; restored when this cell finishes.
main_path = os.getcwd()

classifier_type = 'between_truth'

# Absolute path of the directory that holds the per-subject folders.
model_path = os.path.join(main_path, classifier_type)

try:
    # Subject folders are named 's...A' or 's...B'; sort for a reproducible order.
    for folder in sorted(os.listdir(model_path)):
        folder_path = os.path.join(model_path, folder)
        if os.path.isdir(folder_path) and folder.startswith('s') and folder.endswith(('A', 'B')):
            # main() reads its inputs from, and writes logs/checkpoints to,
            # the current working directory.
            os.chdir(folder_path)
            main(classifier_type, epochs=N_epoch, batch_size = N_batch, lr = lr)
finally:
    # Always return to the starting directory so the cell can be re-run and
    # later cells resolve their relative paths correctly.
    os.chdir(main_path)

Start training at: 2024-06-12 13:26:15
Epoch 1, start
Epoch 001, Train Loss: 1.3375, Timestamp: 2024-06-12 13:26:27,
Val Loss: 0.9920, Val ACC cond: 50.00%
Epoch 2, start
Epoch 002, Train Loss: 0.8128, Timestamp: 2024-06-12 13:26:39,
Val Loss: 0.7909, Val ACC cond: 50.00%
Epoch 3, start
Epoch 003, Train Loss: 0.7228, Timestamp: 2024-06-12 13:26:51,
Val Loss: 0.7034, Val ACC cond: 50.00%
Epoch 4, start
Epoch 004, Train Loss: 0.6866, Timestamp: 2024-06-12 13:27:03,
Val Loss: 0.7099, Val ACC cond: 50.00%
Epoch 5, start
Epoch 005, Train Loss: 0.6722, Timestamp: 2024-06-12 13:27:15,
Val Loss: 0.9807, Val ACC cond: 50.00%
Epoch 6, start
Epoch 006, Train Loss: 0.7711, Timestamp: 2024-06-12 13:27:27,
Val Loss: 0.8694, Val ACC cond: 50.00%
Epoch 7, start
Epoch 007, Train Loss: 0.7348, Timestamp: 2024-06-12 13:27:39,
Val Loss: 0.6918, Val ACC cond: 50.00%
Epoch 8, start
Epoch 008, Train Loss: 0.6754, Timestamp: 2024-06-12 13:27:52,
Val Loss: 0.7002, Val ACC cond: 50.00%
Epoch 9, start
Epoch 009,

Epoch 009, Train Loss: 0.6569, Timestamp: 2024-06-12 13:40:12,
Val Loss: 0.6354, Val ACC cond: 50.00%
Epoch 10, start
Epoch 010, Train Loss: 0.6423, Timestamp: 2024-06-12 13:40:25,
Val Loss: 0.6234, Val ACC cond: 50.00%
Start training at: 2024-06-12 13:40:26
Epoch 1, start
Epoch 001, Train Loss: 1.2351, Timestamp: 2024-06-12 13:40:37,
Val Loss: 0.7252, Val ACC cond: 50.00%
Epoch 2, start
Epoch 002, Train Loss: 0.9800, Timestamp: 2024-06-12 13:40:49,
Val Loss: 0.7931, Val ACC cond: 50.00%
Epoch 3, start
Epoch 003, Train Loss: 0.7805, Timestamp: 2024-06-12 13:41:02,
Val Loss: 0.8520, Val ACC cond: 0.00%
Epoch 4, start
Epoch 004, Train Loss: 0.7714, Timestamp: 2024-06-12 13:41:14,
Val Loss: 0.7208, Val ACC cond: 50.00%
Epoch 5, start
Epoch 005, Train Loss: 0.7191, Timestamp: 2024-06-12 13:41:26,
Val Loss: 0.6981, Val ACC cond: 50.00%
Epoch 6, start
Epoch 006, Train Loss: 0.6922, Timestamp: 2024-06-12 13:41:39,
Val Loss: 0.6992, Val ACC cond: 50.00%
Epoch 7, start
Epoch 007, Train Loss: 0.

Epoch 007, Train Loss: 0.6756, Timestamp: 2024-06-12 13:54:00,
Val Loss: 0.6842, Val ACC cond: 50.00%
Epoch 8, start
Epoch 008, Train Loss: 0.6319, Timestamp: 2024-06-12 13:54:11,
Val Loss: 0.6708, Val ACC cond: 50.00%
Epoch 9, start
Epoch 009, Train Loss: 0.6175, Timestamp: 2024-06-12 13:54:23,
Val Loss: 0.6920, Val ACC cond: 50.00%
Epoch 10, start
Epoch 010, Train Loss: 0.5624, Timestamp: 2024-06-12 13:54:34,
Val Loss: 0.6378, Val ACC cond: 100.00%
Start training at: 2024-06-12 13:54:35
Epoch 1, start
Epoch 001, Train Loss: 1.1746, Timestamp: 2024-06-12 13:54:46,
Val Loss: 1.3972, Val ACC cond: 50.00%
Epoch 2, start
Epoch 002, Train Loss: 0.9391, Timestamp: 2024-06-12 13:54:58,
Val Loss: 0.7131, Val ACC cond: 50.00%
Epoch 3, start
Epoch 003, Train Loss: 0.7708, Timestamp: 2024-06-12 13:55:09,
Val Loss: 0.7853, Val ACC cond: 50.00%
Epoch 4, start
Epoch 004, Train Loss: 0.7450, Timestamp: 2024-06-12 13:55:21,
Val Loss: 0.6873, Val ACC cond: 50.00%
Epoch 5, start
Epoch 005, Train Loss: 

Epoch 005, Train Loss: 0.7198, Timestamp: 2024-06-12 14:08:16,
Val Loss: 0.7542, Val ACC cond: 50.00%
Epoch 6, start
Epoch 006, Train Loss: 0.6849, Timestamp: 2024-06-12 14:08:28,
Val Loss: 0.7789, Val ACC cond: 50.00%
Epoch 7, start
Epoch 007, Train Loss: 0.7839, Timestamp: 2024-06-12 14:08:41,
Val Loss: 0.8288, Val ACC cond: 50.00%
Epoch 8, start
Epoch 008, Train Loss: 0.7192, Timestamp: 2024-06-12 14:08:54,
Val Loss: 0.6943, Val ACC cond: 50.00%
Epoch 9, start
Epoch 009, Train Loss: 0.6827, Timestamp: 2024-06-12 14:09:07,
Val Loss: 0.7099, Val ACC cond: 50.00%
Epoch 10, start
Epoch 010, Train Loss: 0.6717, Timestamp: 2024-06-12 14:09:19,
Val Loss: 0.6832, Val ACC cond: 50.00%
Start training at: 2024-06-12 14:09:20
Epoch 1, start
Epoch 001, Train Loss: 1.1618, Timestamp: 2024-06-12 14:09:33,
Val Loss: 0.8600, Val ACC cond: 50.00%
Epoch 2, start
Epoch 002, Train Loss: 0.7561, Timestamp: 2024-06-12 14:09:46,
Val Loss: 0.8328, Val ACC cond: 50.00%
Epoch 3, start


In [None]:
##################################################
# Train one classifier per subject folder under ./between_comp.
N_epoch = 10
N_batch = 12
lr = 0.001

# Directory the notebook was started from; restored when this cell finishes.
main_path = os.getcwd()

classifier_type = 'between_comp'

# Absolute path of the directory that holds the per-subject folders.
model_path = os.path.join(main_path, classifier_type)

try:
    # Subject folders are named 's...A' or 's...B'; sort for a reproducible order.
    for folder in sorted(os.listdir(model_path)):
        folder_path = os.path.join(model_path, folder)
        if os.path.isdir(folder_path) and folder.startswith('s') and folder.endswith(('A', 'B')):
            # main() reads its inputs from, and writes logs/checkpoints to,
            # the current working directory.
            os.chdir(folder_path)
            main(classifier_type, epochs=N_epoch, batch_size = N_batch, lr = lr)
finally:
    # Always return to the starting directory so the cell can be re-run and
    # later cells resolve their relative paths correctly.
    os.chdir(main_path)

In [None]:
##################################################
# Train one classifier per subject folder under ./between_cond.
N_epoch = 5
N_batch = 16
lr = 0.0001

# Directory the notebook was started from; restored when this cell finishes.
main_path = os.getcwd()

classifier_type = 'between_cond'

# Absolute path of the directory that holds the per-subject folders.
model_path = os.path.join(main_path, classifier_type)

try:
    # Subject folders are named 's...A' or 's...B'; sort for a reproducible order.
    for folder in sorted(os.listdir(model_path)):
        folder_path = os.path.join(model_path, folder)
        if os.path.isdir(folder_path) and folder.startswith('s') and folder.endswith(('A', 'B')):
            # main() reads its inputs from, and writes logs/checkpoints to,
            # the current working directory.
            os.chdir(folder_path)
            main(classifier_type, epochs=N_epoch, batch_size = N_batch, lr = lr)
finally:
    # Always return to the starting directory so the cell can be re-run and
    # later cells resolve their relative paths correctly.
    os.chdir(main_path)