In [1]:
##### this is for classification in single subject

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np
import os
import nibabel as nib
import random
import datetime
from torch.utils.tensorboard import SummaryWriter
import logging
import torch.optim as optim
from sklearn.metrics import precision_recall_fscore_support

2024-06-14 00:34:55.003265: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-14 00:34:55.003298: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-14 00:34:55.003709: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-06-14 00:34:55.006528: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
class MRIDataset(Dataset):
    def __init__(self, file_path, path_prefix=""):
        self.path_prefix = path_prefix
        # full_file_path = path_prefix + file_path
        full_file_path = 'condition+classify.txt'
        with open(full_file_path, 'r') as file:
            data = [line.strip().split() for line in file.readlines()]
        
        self.cond_label = {cond_label: idx for idx, cond_label in enumerate(set(row[1] for row in data))}
        self.files = [(row[0], self.cond_label[row[1]]) for row in data]
        random.shuffle(self.files)
    
    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path, cond_label = self.files[idx]
        full_img_path = self.path_prefix + img_path
        
        img = nib.load(full_img_path).get_fdata()
        img = np.float32(img)
        img = torch.from_numpy(img)
        if img.ndim == 4 and img.shape[-1] == 1:
            img = img.squeeze(-1)
        img = img.unsqueeze(0)
        cond_label = torch.tensor(cond_label, dtype=torch.long)
        return img, cond_label

In [4]:
class BasicBlock3D(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock3D, self).__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(out_channels)
        
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(out_channels)
            )

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        
        if self.downsample is not None:
            identity = self.downsample(x)
        
        out += identity
        out = self.relu(out)
        return out

In [5]:
class ResNet3D(nn.Module):
    def __init__(self, block, layers, num_classes=2):
        super(ResNet3D, self).__init__()
        self.in_channels = 64
        self.conv1 = nn.Conv3d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        layers = []
        layers.append(block(self.in_channels, out_channels, stride))
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [6]:
def resnet18_3d(num_classes=10):
    return ResNet3D(BasicBlock3D, [2, 2, 2, 2], num_classes)

In [7]:
def train_model(model, train_loader, criterion, optimizer, device, test_data=None, epochs=10):
    model.train()
    with open('training_log.txt', 'a') as log_file:
        for epoch in range(epochs):
            total_loss = 0
            num_batches = 0
            for inputs, labels in train_loader:
                labels = labels.to(device)
                inputs = inputs.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()
                num_batches += 1
            average_loss = total_loss / num_batches
            current_utc = datetime.datetime.utcnow()
            gmt8_time = current_utc + datetime.timedelta(hours=8)
            current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
            log_entry = f'Epoch {epoch+1:03}, Average Loss: {average_loss}, Timestamp: {current_time}\n'
            # Write the log entry to the file
            log_file.write(log_entry)
            # Test every 10 epochs
            if (epoch + 1) % 10 == 0 and test_data is not None:
                test_input, test_label = test_data
                test_result = test_model(model, test_input, test_label, device)
                log_file.write(f"Test at Epoch {epoch+1:03}: {test_result}\n")


def test_model(model, test_input, test_label, device):
    model.eval()

    # Ensure test_input is a tensor and move it to the correct device
    if not torch.is_tensor(test_input):
        test_input = torch.tensor(test_input, dtype=torch.float, device=device)
    else:
        test_input = test_input.to(device)

    # Ensure test_label is a tensor, add a batch dimension, and move to correct device
    if isinstance(test_label, int):
        test_label = torch.tensor([test_label], dtype=torch.long, device=device)
    else:
        test_label = test_label.to(device)

    with torch.no_grad():
        # Perform model inference and get the predicted class
        outputs = model(test_input.unsqueeze(0))
        _, predicted = torch.max(outputs, 1)

        # Check if the prediction is correct
        correct = (predicted == test_label).item()  # Convert the result to Python boolean

    return "Correct" if correct else "Incorrect"

In [8]:
def main(classifier, path_prefix="", epochs=10, lr=0.001, batch=8):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')
    
    dataset_file = f'{classifier}+classify.txt'
    
    full_dataset = MRIDataset(dataset_file, path_prefix=path_prefix)
    num_classes = len(full_dataset.cond_label)
    grand_results = []
    
    current_utc = datetime.datetime.utcnow()
    gmt8_time = current_utc + datetime.timedelta(hours=8)
    current_time = gmt8_time.strftime("%Y-%m-%d %H:%M:%S")
    # current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    start_time = f'Start training at: {current_time}'
    print(start_time)
    
    for i in range(len(full_dataset)):
        train_indices = list(range(len(full_dataset)))
        train_indices.pop(i)  # Remove the test image index
        test_index = i
        
        train_subset = Subset(full_dataset, train_indices)
        test_input, test_label = full_dataset[test_index]

        train_loader = DataLoader(train_subset, batch_size=batch, shuffle=True)
            
        model = resnet18_3d(num_classes)
        model = model.to(device)
        
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        train_model(model, train_loader, criterion, optimizer, device, test_data=full_dataset[test_index], epochs=epochs)
        result = test_model(model, test_input, test_label, device)
        grand_results.append(result)
        
        os.rename('training_log.txt', f'training_log+stim_{i+1:03}.txt')

    with open(f'{classifier}_final_results.log', 'w') as f:
        correct_count = grand_results.count("Correct")
        total_tests = len(grand_results)
        correct_percentage = (correct_count / total_tests) * 100 if total_tests > 0 else 0
        for idx, result in enumerate(grand_results):
            f.write(f"Model {idx+1:03}: Result: {result}\n")
        f.write(f"Percentage of Correct Predictions: {correct_percentage:.2f}%\n")

    print(f"Percentage of Correct Predictions: {correct_percentage:.2f}%")



In [9]:
##################################################
N_ep = 100
N_batch = 8

main_path = os.getcwd()
errts_path = '../../preprocess/errts'
classifier_type = 'within_comp'
os.chdir(classifier_type)

folder_list = [folder for folder in os.listdir() if folder.startswith('s') and os.path.isdir(folder)]

for folder in folder_list:
    os.chdir(folder)
    path_to_main = '../../'
    full_errts_path = path_to_main + errts_path + '/' + folder + '/'
    main(classifier_type, path_prefix=full_errts_path, epochs=N_ep, batch=N_batch)
    
    # Return to the parent directory
    os.chdir('..')


Using device: cuda
Start training at: 2024-06-14 08:34:56
Percentage of Correct Predictions: 61.70%
Using device: cuda
Start training at: 2024-06-14 11:51:31
Percentage of Correct Predictions: 58.97%
Using device: cuda
Start training at: 2024-06-14 14:03:10
Percentage of Correct Predictions: 81.58%
Using device: cuda
Start training at: 2024-06-14 16:07:18
Percentage of Correct Predictions: 78.05%


In [10]:
os.getcwd()

'/mnt/1122_DL/Final/models/ResNet/within_comp'

In [11]:
classifier_type

'within_comp'

In [12]:
os.chdir(classifier_type)

FileNotFoundError: [Errno 2] No such file or directory: 'within_comp'