In [1]:
##### Single-subject classification: a separate classifier is trained and tested for each subject folder.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import os
import nibabel as nib
import random
import datetime
from torch.utils.tensorboard import SummaryWriter
import logging
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support

2024-05-13 08:48:30.784639: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-13 08:48:30.784680: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-13 08:48:30.785100: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-13 08:48:30.787982: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
class MRIDataset(Dataset):
    """Dataset of NIfTI volumes listed in a whitespace-separated index file.

    Every non-blank line of the index file must hold an image path followed
    by a label string (``<image_path> <label>``); further columns are ignored.
    Label strings are mapped to integer class ids in sorted order so that the
    mapping is the same on every run (iterating a plain ``set`` of strings is
    not stable across Python processes).

    Args:
        file_path: Name of the index file; appended to ``path_prefix``.
        path_prefix: String prepended verbatim (plain concatenation, no
            separator is inserted) to ``file_path`` and to every image path.

    Attributes:
        labels: Dict mapping each label string to its integer class id.
        files: List of ``(image path, class id)`` pairs, shuffled once at
            construction time using the global ``random`` state.

    Raises:
        ValueError: If a non-blank line has fewer than two columns.
    """

    def __init__(self, file_path, path_prefix=""):
        # Initialize with an optional path prefix
        self.path_prefix = path_prefix

        full_file_path = path_prefix + file_path
        with open(full_file_path, 'r') as file:
            # Blank lines (e.g. a trailing newline) are skipped rather than
            # crashing later on the missing label column.
            rows = [line.split() for line in file if line.strip()]

        for row in rows:
            if len(row) < 2:
                raise ValueError(
                    f"{full_file_path}: expected '<image_path> <label>' per line, "
                    f"got {' '.join(row)!r}"
                )

        # sorted() makes the label -> id assignment deterministic.
        unique_labels = sorted({row[1] for row in rows})
        self.labels = {label: idx for idx, label in enumerate(unique_labels)}
        self.files = [(row[0], self.labels[row[1]]) for row in rows]
        random.shuffle(self.files)  # Shuffle the list of files

    def __len__(self):
        """Return the number of (image, label) entries."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load one volume and return ``(tensor, class id)``.

        The tensor is float32 with a leading channel axis of size 1; a
        trailing singleton 4th axis in the stored image is dropped first.
        """
        img_path, label = self.files[idx]
        full_img_path = self.path_prefix + img_path
        img = nib.load(full_img_path).get_fdata()
        img = np.float32(img)
        img = torch.from_numpy(img)
        if img.ndim == 4 and img.shape[-1] == 1:
            img = img.squeeze(-1)
        img = img.unsqueeze(0)  # Ensure channel dimension is present
        return img, label

In [4]:
class ConvNeXtT(nn.Module):
    """Small 3-D residual CNN classifier for single-channel volumes.

    The network has three Conv3d -> LayerNorm -> GELU stages with 16, 32 and
    64 channels; the second and third stage use stride 2. Each stage output
    is added to a 1x1x1 convolution of the stage input (the skip path), and
    the result is globally average-pooled and fed to a linear layer.

    The LayerNorm layers normalise over the whole (C, D, H, W) activation, so
    the model only accepts inputs whose spatial size equals ``input_shape``.

    Args:
        num_classes: Number of output logits.
        input_shape: Spatial size ``(D, H, W)`` of the input volumes. Defaults
            to ``(64, 76, 64)``, the size that used to be hard-coded.

    Raises:
        ValueError: If ``input_shape`` is not three positive integers.
    """

    def __init__(self, num_classes, input_shape=(64, 76, 64)):
        super().__init__()
        input_shape = tuple(int(n) for n in input_shape)
        if len(input_shape) != 3 or min(input_shape) < 1:
            raise ValueError(
                f"input_shape must be three positive integers, got {input_shape}"
            )
        # Both stride-2 convolutions used here (kernel 3 / padding 1 on the
        # main path, kernel 1 / padding 0 on the skip path) map a spatial
        # length n to (n - 1) // 2 + 1, so the two paths always agree.
        half_shape = tuple((n - 1) // 2 + 1 for n in input_shape)
        quarter_shape = tuple((n - 1) // 2 + 1 for n in half_shape)

        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, stride=1, padding=1)
        self.ln1 = nn.LayerNorm([16, *input_shape])
        self.gelu = nn.GELU()
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1)
        self.ln2 = nn.LayerNorm([32, *half_shape])
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1)
        self.ln3 = nn.LayerNorm([64, *quarter_shape])

        self.adapool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(64, num_classes)  # Output layer: one logit per class

        # 1x1x1 projections so each skip path matches its stage's output
        # channel count and stride.
        self.match_conv1 = nn.Conv3d(1, 16, kernel_size=1, stride=1, padding=0)  # pairs with conv1
        self.match_conv2 = nn.Conv3d(16, 32, kernel_size=1, stride=2, padding=0)  # pairs with conv2
        self.match_conv3 = nn.Conv3d(32, 64, kernel_size=1, stride=2, padding=0)  # pairs with conv3

    def forward(self, x):
        """Map a ``(N, 1, D, H, W)`` batch to ``(N, num_classes)`` logits."""
        stages = (
            (self.match_conv1, self.conv1, self.ln1),
            (self.match_conv2, self.conv2, self.ln2),
            (self.match_conv3, self.conv3, self.ln3),
        )
        for shortcut, conv, norm in stages:
            identity = shortcut(x)  # Projected stage input for the skip connection
            x = self.gelu(norm(conv(x)))
            x += identity  # Add skip connection

        # Adaptive pooling and final fully connected layer
        x = self.adapool(x)
        x = x.view(x.size(0), -1)  # Flatten to (N, 64)
        return self.fc(x)

In [5]:
def train_model(model, train_loader, criterion, optimizer, device, test_data=None, epochs=10):
    """Train ``model`` and append one log line per epoch to ``training_log.txt``.

    The log file is opened in append mode in the current working directory.
    Each epoch writes the mean batch loss and a GMT+8 timestamp. When
    ``test_data`` is given, the model is additionally evaluated on it after
    every 10th epoch and the outcome is written to the same log.

    Args:
        model: Module to optimise; already placed on ``device``.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimiser bound to ``model``'s parameters.
        device: Device the batches are moved to.
        test_data: Optional ``(test_input, test_label)`` pair passed to
            ``test_model`` every 10 epochs.
        epochs: Number of passes over ``train_loader``.
    """
    model.train()
    gmt8 = datetime.timezone(datetime.timedelta(hours=8))
    with open('training_log.txt', 'a') as log_file:
        for epoch in range(epochs):
            # The periodic evaluation below may switch the model to eval
            # mode, so training mode is re-asserted at the start of each epoch.
            model.train()
            total_loss = 0.0
            num_batches = 0
            for inputs, labels in train_loader:
                labels = labels.to(device)
                inputs = inputs.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                num_batches += 1
            # An empty loader yields no batches; log NaN instead of dividing by zero.
            average_loss = total_loss / num_batches if num_batches else float('nan')
            # Timezone-aware replacement for the deprecated utcnow() + 8 h arithmetic.
            current_time = datetime.datetime.now(gmt8).strftime("%Y-%m-%d %H:%M:%S")
            log_entry = f'Epoch {epoch+1:03}, Average Loss: {average_loss}, Timestamp: {current_time}\n'
            # Write the log entry to the file
            log_file.write(log_entry)
            # Test every 10 epochs
            if (epoch + 1) % 10 == 0 and test_data is not None:
                test_input, test_label = test_data
                test_result = test_model(model, test_input, test_label, device)
                log_file.write(f"Test at Epoch {epoch+1:03}: {test_result}\n")


def test_model(model, test_input, test_label, device):
    model.eval()

    # Ensure test_input is a tensor and move it to the correct device
    if not torch.is_tensor(test_input):
        test_input = torch.tensor(test_input, dtype=torch.float, device=device)
    else:
        test_input = test_input.to(device)

    # Ensure test_label is a tensor, add a batch dimension, and move to correct device
    if isinstance(test_label, int):
        test_label = torch.tensor([test_label], dtype=torch.long, device=device)
    else:
        test_label = test_label.to(device)

    with torch.no_grad():
        # Perform model inference and get the predicted class
        outputs = model(test_input.unsqueeze(0))
        _, predicted = torch.max(outputs, 1)

        # Check if the prediction is correct
        correct = (predicted == test_label).item()  # Convert the result to Python boolean

    return "Correct" if correct else "Incorrect"

In [6]:
def main(classifier, path_prefix="", epochs=10, lr=0.001):
    """Run leave-one-out cross-validation for one index file.

    Reads ``{classifier}+classify.txt`` (under ``path_prefix``) through
    ``MRIDataset``. For every sample a fresh ``ConvNeXtT`` is trained on all
    other samples and then tested on the held-out one. All output files are
    written to the current working directory:

    * ``training_log+stim_NNN.txt`` - per-fold training log,
    * ``{classifier}_final_results.log`` - per-fold outcome and overall accuracy.

    NOTE: ``MRIDataset`` shuffles its file list, so fold number ``NNN`` is the
    position in the shuffled list, not the line number in the index file.

    Args:
        classifier: Prefix of the index file name and of the results log.
        path_prefix: Prefix forwarded to ``MRIDataset``.
        epochs: Training epochs per fold.
        lr: Adam learning rate.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')

    dataset_file = f'{classifier}+classify.txt'

    full_dataset = MRIDataset(dataset_file, path_prefix=path_prefix)
    num_classes = len(full_dataset.labels)
    grand_results = []

    # Timezone-aware replacement for the deprecated utcnow() + 8 h arithmetic.
    gmt8 = datetime.timezone(datetime.timedelta(hours=8))
    current_time = datetime.datetime.now(gmt8).strftime("%Y-%m-%d %H:%M:%S")
    start_time = f'Start training at: {current_time}'
    print(start_time)

    fold_log = 'training_log.txt'
    n_samples = len(full_dataset)
    for i in range(n_samples):
        # Every index except the held-out one.
        train_indices = [j for j in range(n_samples) if j != i]

        train_subset = Subset(full_dataset, train_indices)
        # Load the held-out volume once and reuse it for the periodic and the
        # final evaluation.
        test_input, test_label = full_dataset[i]

        train_loader = DataLoader(train_subset, batch_size=4, shuffle=True)

        model = ConvNeXtT(num_classes)
        model = model.to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr)

        # train_model appends to the log; drop a leftover from an interrupted
        # run so it does not end up inside this fold's log.
        if os.path.exists(fold_log):
            os.remove(fold_log)

        train_model(model, train_loader, criterion, optimizer, device,
                    test_data=(test_input, test_label), epochs=epochs)
        result = test_model(model, test_input, test_label, device)
        grand_results.append(result)

        # os.replace overwrites an existing target on every platform.
        os.replace(fold_log, f'training_log+stim_{i+1:03}.txt')

    correct_count = grand_results.count("Correct")
    total_tests = len(grand_results)
    correct_percentage = (correct_count / total_tests) * 100 if total_tests > 0 else 0
    with open(f'{classifier}_final_results.log', 'w') as f:
        for idx, result in enumerate(grand_results):
            f.write(f"Model {idx+1:03}: Result: {result}\n")
        f.write(f"Percentage of Correct Predictions: {correct_percentage:.2f}%\n")

    print(f"Percentage of Correct Predictions: {correct_percentage:.2f}%")



In [7]:
##################################################
# Driver: run the leave-one-out classification for every subject folder,
# once per classifier type. Results are written to
# <main_path>/<classifier_type>/<subject folder>/.
N_ep = 200

main_path = os.getcwd()
errts_path = '../../preprocess/errts'
# Resolve the data directory once so it stays valid after every chdir below.
errts_abs_path = os.path.abspath(os.path.join(main_path, errts_path))

# Subject folders are the directories whose name starts with 's'; sorted so
# the processing order does not depend on the file system.
folder_list = sorted(
    folder for folder in os.listdir(errts_abs_path)
    if folder.startswith('s') and os.path.isdir(os.path.join(errts_abs_path, folder))
)


def _make_output_dir(path):
    """Create ``path`` if needed, reporting when it already exists."""
    try:
        os.mkdir(path)
    except FileExistsError:
        print(f"Folder '{os.path.basename(path)}' already exists.")


try:
    for classifier_type in ('condition', 'truth_tell'):
        classifier_dir = os.path.join(main_path, classifier_type)
        _make_output_dir(classifier_dir)

        for folder in folder_list:
            subject_dir = os.path.join(classifier_dir, folder)
            _make_output_dir(subject_dir)
            # main() writes its logs into the current working directory.
            os.chdir(subject_dir)
            # The trailing separator is required: MRIDataset concatenates the
            # prefix and the file name without inserting one.
            full_errts_path = os.path.join(errts_abs_path, folder, '')
            main(classifier_type, path_prefix=full_errts_path, epochs=N_ep)
finally:
    # Restore the working directory even on error or KeyboardInterrupt so
    # that re-running this cell resolves main_path correctly.
    os.chdir(main_path)

Using device: cuda
Start training at: 2024-05-13 16:48:32


KeyboardInterrupt: 