In [1]:
import os
import datetime
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
import numpy as np
import nibabel as nib
import random
from pathlib import Path
from timm.models.layers import trunc_normal_

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class MRIDataset(Dataset):
    """Dataset of NIfTI volumes listed in a whitespace-separated text file.

    Each non-blank line of the list file holds an image path followed by a
    condition label: ``<image_path> <label>``. Labels are mapped to integer
    class indices in sorted order, so the mapping is identical on every run
    (enumerating a bare ``set`` of strings is not stable across interpreter
    sessions).

    Args:
        file_path: Path of the list file, appended to ``path_prefix``.
        path_prefix: String prepended verbatim (plain concatenation, no
            separator is inserted) to the list file path and to every image
            path read from it.

    Attributes set by ``__init__``:
        cond_label: dict mapping label string -> class index.
        files: shuffled list of ``(image_path, class_index)`` tuples.
    """

    def __init__(self, file_path, path_prefix=""):
        self.path_prefix = path_prefix
        full_file_path = path_prefix + file_path
        with open(full_file_path, 'r') as file:
            # Skip blank lines (e.g. a trailing newline) instead of failing on row[1].
            data = [line.split() for line in file if line.strip()]

        # Sort the labels so that class indices are deterministic between runs.
        labels = sorted({row[1] for row in data})
        self.cond_label = {label: idx for idx, label in enumerate(labels)}
        self.files = [(row[0], self.cond_label[row[1]]) for row in data]
        # Sample order is randomized once at construction time.
        random.shuffle(self.files)

    def __len__(self):
        """Return the number of listed samples."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one volume and return ``(img, cond_label)``.

        ``img`` is a float32 tensor with a leading singleton dimension added;
        a trailing singleton 4th dimension of the stored array is dropped
        first. ``cond_label`` is a 0-d ``torch.long`` tensor.
        """
        img_path, cond_label = self.files[idx]
        full_img_path = self.path_prefix + img_path

        img = nib.load(full_img_path).get_fdata()
        img = np.float32(img)
        img = torch.from_numpy(img)
        if img.ndim == 4 and img.shape[-1] == 1:
            img = img.squeeze(-1)
        img = img.unsqueeze(0)
        cond_label = torch.tensor(cond_label, dtype=torch.long)
        return img, cond_label

In [3]:
# class LayerNorm(nn.Module):
#     def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
#         super().__init__()
#         self.weight = nn.Parameter(torch.ones(normalized_shape))
#         self.bias = nn.Parameter(torch.zeros(normalized_shape))
#         self.eps = eps
#         self.data_format = data_format
#         self.normalized_shape = (normalized_shape,)

#     def forward(self, x):
#         if self.data_format == "channels_last":
#             return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
#         elif self.data_format == "channels_first":
#             u = x.mean(1, keepdim=True)
#             s = (x - u).pow(2).mean(1, keepdim=True)
#             x = (x - u) / torch.sqrt(s + self.eps)
#             x = self.weight[:, None, None, None] * x + self.bias[:, None, None, None]
#             return x

class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.

    channels_last expects the channel axis last, e.g. (batch_size, height, width, channels),
    and delegates to ``F.layer_norm``. channels_first expects the channel axis at position 1,
    e.g. (batch_size, channels, height, width), and normalizes over that axis by hand.
    The channels_first path works for any number of trailing spatial dimensions
    (2D images, 3D volumes, ...), not only for 4-D input.

    Args:
        normalized_shape (int): Number of channels being normalized.
        eps (float): Value added to the variance for numerical stability. Default: 1e-6
        data_format (str): "channels_last" or "channels_first". Default: "channels_last"

    Raises:
        NotImplementedError: If ``data_format`` is not one of the two supported values.
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(f"unsupported data_format: {data_format!r}")
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Mean and (biased) variance over the channel axis.
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            # Reshape the per-channel parameters to (C, 1, ..., 1) with one
            # singleton per spatial dimension so they broadcast against
            # (N, C, *spatial) whatever the input rank is.
            param_shape = (-1,) + (1,) * (x.dim() - 2)
            x = self.weight.view(param_shape) * x + self.bias.view(param_shape)
            return x

In [4]:
# class Block(nn.Module):
#     def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
#         super().__init__()
#         # Depthwise 3D convolution
#         # self.dwconv = nn.Conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)
#         self.dwconv = Dynamic_conv3d(dim, dim, kernel_size=7, padding=3, groups=dim)
#         # Layer normalization for 3D (adjusting for channel dimension)
#         self.norm = LayerNorm(dim, eps=1e-6)
#         # Pointwise convolutions using linear layers
#         self.pwconv1 = nn.Linear(dim, 4 * dim)
#         self.act = nn.GELU()
#         self.pwconv2 = nn.Linear(4 * dim, dim)
#         # Layer scaling if it is utilized
#         self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), 
#                                   requires_grad=True) if layer_scale_init_value > 0 else None
#         self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

#     def forward(self, x):
#         input = x
#         x = self.dwconv(x)
#         x = x.permute(0, 2, 3, 4, 1)  # Permute to bring channel to last
#         x = self.norm(x)
#         x = self.pwconv1(x)
#         x = self.act(x)
#         x = self.pwconv2(x)
#         if self.gamma is not None:
#             x = self.gamma * x
#         x = x.permute(0, 4, 1, 2, 3)  # Permute back to normal

#         x = input + self.drop_path(x)
#         return x
class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    This class implements (2).

    The block output is ``x + drop_path(gamma * f(x))``, so the input shape is preserved.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale; a value <= 0
            disables layer scaling. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                    requires_grad=True) if layer_scale_init_value > 0 else None
        if drop_path > 0.:
            # DropPath is not among the notebook's top-level imports; without this
            # import any drop_path > 0 fails with NameError. It lives in the same
            # timm module the notebook already imports trunc_normal_ from.
            from timm.models.layers import DropPath
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        shortcut = x  # kept for the residual connection
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)

        x = shortcut + self.drop_path(x)
        return x

In [5]:
class ConvNeXt(nn.Module):
    r""" ConvNeXt-style 2D classifier.

    The network is a stem followed by alternating downsampling layers and stages of
    residual ``Block`` modules, then global average pooling, a LayerNorm, two
    4096-wide fully connected layers with GELU, and a linear classification head.
    The number of stages is taken from ``len(depths)``.

    Args:
        in_chans (int): Number of input channels. Default: 64
        num_classes (int): Number of output classes. Default: 4
        depths (sequence of int): Number of blocks in each stage. Default: [3, 3, 9, 3]
        dims (sequence of int): Channel width of each stage; must have the same
            length as ``depths``. Default: [96, 192, 384, 768]
        drop_path_rate (float): Maximum stochastic depth rate; per-block rates
            grow linearly from 0 to this value. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Multiplier applied to the initial head weights
            and bias. Default: 1.

    Raises:
        ValueError: If ``depths`` is empty or ``depths`` and ``dims`` differ in length.
    """
    def __init__(self, in_chans=64, num_classes=4, 
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., 
                 layer_scale_init_value=1e-6, head_init_scale=1.,
                 ):
        super().__init__()

        if len(depths) == 0 or len(depths) != len(dims):
            raise ValueError(
                f"depths and dims must be non-empty and of equal length, "
                f"got {len(depths)} and {len(dims)}"
            )
        num_stages = len(depths)

        self.downsample_layers = nn.ModuleList() # stem and the intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=5, stride=3),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(num_stages - 1):
            downsample_layer = nn.Sequential(
                    LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                    nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList() # feature resolution stages, each consisting of multiple residual blocks
        # One stochastic-depth rate per block, increasing linearly over the whole network.
        dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 
        cur = 0
        for i in range(num_stages):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j], 
                layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
        self.fc_layers = nn.Sequential(
            nn.Linear(dims[-1], 4096),
            nn.GELU(),
            nn.Linear(4096, 4096),
            nn.GELU()
        )

        self.head = nn.Linear(4096, num_classes)

        self.apply(self._init_weights)
        self.head.weight.data.mul_(head_init_scale)
        self.head.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        """Truncated-normal init for conv/linear weights, zeros for their biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            # Layers built with bias=False have m.bias set to None.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        """Run the downsample/stage pairs and return pooled, normalized features (N, C)."""
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = downsample(x)
            x = stage(x)
        return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        """Return class logits of shape (N, num_classes) for input (N, in_chans, H, W)."""
        x = self.forward_features(x)
        x = self.fc_layers(x)
        x = self.head(x)
        return x

In [6]:
def validate(model, test_input, test_label, device):
    """Classify a single held-out sample and report whether the prediction is right.

    The input goes through the same reshaping as the training batches: a
    singleton dimension 1 is squeezed if present, then the axes are permuted
    with ``(0, 3, 2, 1)`` so that the last axis becomes the channel axis.

    The model is put in eval mode for the forward pass and restored to the mode
    it was in before the call, so calling this in the middle of training does
    not leave the model in eval mode.

    Args:
        model: Module mapping the permuted input to class logits ``(N, num_classes)``.
        test_input: Tensor, or anything ``torch.tensor`` accepts.
        test_label: Ground-truth class as an int, a tensor, or another
            array-like convertible to a long tensor.
        device: Device on which inference runs.

    Returns:
        str: ``"Correct"`` if the arg-max class equals ``test_label``, else ``"Incorrect"``.
    """
    # Ensure test_input is a tensor and move it to the correct device
    if not torch.is_tensor(test_input):
        test_input = torch.tensor(test_input, dtype=torch.float, device=device)
    else:
        test_input = test_input.to(device)

    # Ensure test_label is a tensor on the correct device; non-tensor labels
    # (plain ints, numpy integers, ...) become a 1-D long tensor.
    if torch.is_tensor(test_label):
        test_label = test_label.to(device)
    else:
        test_label = torch.as_tensor(test_label, dtype=torch.long, device=device).reshape(-1)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Perform model inference and get the predicted class
            test_input = torch.squeeze(test_input, dim=1)
            # test_input = test_input.permute(0, 2, 1, 3).contiguous() # channel - y
            test_input = test_input.permute(0, 3, 2, 1).contiguous() # channel - z
            outputs = model(test_input)
            _, predicted = torch.max(outputs, 1)

            # Check if the prediction is correct
            correct = (predicted == test_label).item()  # Convert the result to Python boolean
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    return "Correct" if correct else "Incorrect"

In [7]:
def train(model, train_loader, criterion, optimizer, device, test_data=None, epochs=10):
    """Train ``model`` and append per-epoch progress to ``training_log.txt``.

    Each batch is reshaped the same way ``validate`` reshapes its input: a
    singleton dimension 1 is squeezed, then the axes are permuted with
    ``(0, 3, 2, 1)`` so that the last axis becomes the channel axis.

    Args:
        model: Module to optimize.
        train_loader: Iterable yielding ``(img, cond)`` batches.
        criterion: Loss called as ``criterion(logits, cond)``.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.
        test_data: Optional ``(test_input, test_label)`` pair; when given it is
            evaluated with ``validate`` every 10th epoch and the result logged.
        epochs: Number of passes over ``train_loader``.

    Side effects:
        Appends one line per epoch (average loss and a GMT+8 timestamp), plus
        the periodic test result, to ``training_log.txt`` in the current
        working directory.
    """
    gmt8 = datetime.timezone(datetime.timedelta(hours=8))
    model.train()
    with open('training_log.txt', 'a') as log_file:
        for epoch in range(epochs):
            # The periodic validate() call below may leave the model in eval
            # mode, so training mode is re-entered at the start of every epoch.
            model.train()
            total_loss = 0
            num_batches = 0
            for img, cond in train_loader:
                img = img.to(device)
                cond = cond.to(device)
                optimizer.zero_grad()
                img = torch.squeeze(img, dim=1)
                # Swap the H axis with the channel axis
                # img = img.permute(0, 2, 1, 3).contiguous() # channel - y
                img = img.permute(0, 3, 2, 1).contiguous() # channel - z
                cond_o = model(img)
                loss = criterion(cond_o, cond)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                num_batches += 1
            # An empty loader yields no batches; log NaN instead of dividing by zero.
            average_loss = total_loss / num_batches if num_batches else float('nan')
            # Timezone-aware replacement for the deprecated datetime.utcnow() + 8h.
            current_time = datetime.datetime.now(gmt8).strftime("%Y-%m-%d %H:%M:%S")
            log_entry = f'Epoch {epoch+1:03}, Average Loss: {average_loss}, Timestamp: {current_time}\n'
            # Write the log entry to the file
            log_file.write(log_entry)
            # Test every 10 epochs
            if (epoch + 1) % 10 == 0 and test_data is not None:
                test_input, test_label = test_data
                test_result = validate(model, test_input, test_label, device)
                log_file.write(f"Test at Epoch {epoch+1:03}: {test_result}\n")
    

In [8]:
def main(classifier, path_prefix="", epochs=10, lr=0.001, batch=8):
    """Run leave-one-out training/evaluation over the samples listed for ``classifier``.

    For every sample ``i`` in ``'{classifier}+classify.txt'`` a fresh ConvNeXt is
    trained on all other samples and then evaluated on sample ``i``.

    Args:
        classifier: Prefix used to build the list file name and the output file names.
        path_prefix: Passed through to ``MRIDataset``.
        epochs: Training epochs per fold.
        lr: Adam learning rate.
        batch: Training batch size.

    Side effects (all relative to the current working directory):
        * ``training_log+stim_XXX.txt`` - one training log per fold.
        * ``{classifier}_final_results.log`` - per-fold results and overall accuracy.
        * ``{classifier}_epochXXX.pth`` - weights of the model trained in the LAST
          fold only; nothing is saved when the dataset is empty.
    """
    # Configuration and Hyperparameters
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')

    dataset_file = f'{classifier}+classify.txt'

    full_dataset = MRIDataset(dataset_file, path_prefix=path_prefix)
    nc_cond = len(full_dataset.cond_label)
    n_samples = len(full_dataset)
    grand_results = []

    # Timezone-aware replacement for the deprecated datetime.utcnow() + 8h.
    gmt8 = datetime.timezone(datetime.timedelta(hours=8))
    current_time = datetime.datetime.now(gmt8).strftime("%Y-%m-%d %H:%M:%S")
    start_time = f'Start training at: {current_time}'
    print(start_time)

    model = None  # stays None when the dataset is empty
    for i in range(n_samples):
        # Hold out sample i, train on everything else.
        train_indices = [j for j in range(n_samples) if j != i]
        train_subset = Subset(full_dataset, train_indices)
        # Load the held-out volume once and reuse it for both the periodic
        # in-training check and the final evaluation.
        test_input, test_label = full_dataset[i]

        train_loader = DataLoader(train_subset, batch_size=batch, shuffle=True)

        model = ConvNeXt(in_chans=64, num_classes=nc_cond)
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        train(model, train_loader, criterion, optimizer, device,
              test_data=(test_input, test_label), epochs=epochs)
        result = validate(model, test_input, test_label, device)
        grand_results.append(result)

        # os.replace overwrites a log left over from a previous run on every platform.
        os.replace('training_log.txt', f'training_log+stim_{i+1:03}.txt')

    with open(f'{classifier}_final_results.log', 'w') as f:
        correct_count = grand_results.count("Correct")
        total_tests = len(grand_results)
        correct_percentage = (correct_count / total_tests) * 100 if total_tests > 0 else 0
        for idx, result in enumerate(grand_results):
            f.write(f"Model {idx+1:03}: Result: {result}\n")
        f.write(f"Percentage of Correct Predictions: {correct_percentage:.2f}%\n")

    print(f"Percentage of Correct Predictions: {correct_percentage:.2f}%")

    # Only the model from the final fold is still in scope here.
    if model is not None:
        torch.save(model.state_dict(), f'{classifier}_epoch{epochs:03}.pth')

In [9]:
##################################################
# Driver: run the leave-one-out experiment once per subject folder
# (every directory under ./condition whose name starts with 's').
N_ep = 50
N_batch = 8

main_path = os.getcwd()
classifier_type = './condition'

try:
    os.chdir(classifier_type)
    condition_root = os.getcwd()

    # os.listdir() returns entries in arbitrary order; sort for a reproducible run order.
    folder_list = sorted(
        folder for folder in os.listdir()
        if folder.startswith('s') and os.path.isdir(folder)
    )
    print("channel = z")
    for folder in folder_list:
        # main() reads and writes files relative to the current working
        # directory, so each subject is processed from inside its own folder.
        os.chdir(os.path.join(condition_root, folder))
        main(classifier_type, epochs=N_ep, batch=N_batch)
finally:
    # Always return to the starting directory, even after an exception, so the
    # cell can be re-run without the relative chdir above failing.
    os.chdir(main_path)

channel = z
Using device: cuda
Start training at: 2024-06-16 20:03:57
Percentage of Correct Predictions: 63.16%
Using device: cuda
Start training at: 2024-06-16 20:07:54
Percentage of Correct Predictions: 73.17%
Using device: cuda
Start training at: 2024-06-16 20:12:12
Percentage of Correct Predictions: 63.64%
Using device: cuda
Start training at: 2024-06-16 20:17:31
Percentage of Correct Predictions: 87.80%
Using device: cuda
Start training at: 2024-06-16 20:21:48
Percentage of Correct Predictions: 63.16%
Using device: cuda
Start training at: 2024-06-16 20:25:43
Percentage of Correct Predictions: 62.22%
Using device: cuda
Start training at: 2024-06-16 20:31:11
Percentage of Correct Predictions: 79.49%
