In [9]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from scipy.io import loadmat
import os
from pathlib import Path

# Custom dataset class for loading .mat files
class MatFileDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.file_paths = []
        
        # Get all .mat files from subdirectories
        for folder in range(1, 181):
            folder_path = os.path.join(root_dir, str(folder))
            if os.path.exists(folder_path):
                mat_files = [f for f in os.listdir(folder_path) if f.endswith('.mat')]
                self.file_paths.extend([os.path.join(folder_path, f) for f in mat_files])
    
    def __len__(self):
        return len(self.file_paths)
    
    def __getitem__(self, idx):
        mat_data = loadmat(self.file_paths[idx])
        # Adjust the key below based on how your data is stored in the .mat file
        data = mat_data['new_brain']  
        # Add channel dimension and convert to torch tensor
        data = torch.FloatTensor(data).unsqueeze(0)
        return data

In [19]:
class Autoencoder3D(nn.Module):
    def __init__(self, latent_dim=256):
        super(Autoencoder3D, self).__init__()
        
        # Create individual layers instead of Sequential for debugging
        # Encoder layers
        self.conv1 = nn.Conv3d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU()
        
        # We'll calculate the flattened size dynamically
        self.flatten = nn.Flatten()
        self.enc_linear = None  # We'll set this after calculating the size
        
        # Decoder layers
        self.dec_linear = None  # We'll set this after calculating the size
        self.dec_unflatten = None  # We'll set this after calculating the size
        self.deconv1 = nn.ConvTranspose3d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv2 = nn.ConvTranspose3d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.deconv3 = nn.ConvTranspose3d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.sigmoid = nn.Sigmoid()
        
        # Initialize with a test forward pass
        self._init_layers(latent_dim)
    
    def _init_layers(self, latent_dim):
        # Run a test input through the convolutional layers to get the correct size
        x = torch.randn(1, 1, 108, 90, 108)
        print(f"Initial input shape: {x.shape}")
        
        x = self.relu(self.conv1(x))
        print(f"After conv1: {x.shape}")
        
        x = self.relu(self.conv2(x))
        print(f"After conv2: {x.shape}")
        
        x = self.relu(self.conv3(x))
        print(f"After conv3: {x.shape}")
        
        x = self.flatten(x)
        print(f"After flatten: {x.shape}")
        
        flattened_size = x.shape[1]
        self.enc_linear = nn.Linear(flattened_size, latent_dim)
        self.dec_linear = nn.Linear(latent_dim, flattened_size)
        
        # Calculate the 3D shape after conv3
        self.final_shape = (64, x.shape[1]//(64))  # Channels, flattened spatial dimensions
        print(f"Final shape for unflattening: {self.final_shape}")
        
    def forward(self, x):
        # Encoder
        print(f"\nForward pass input shape: {x.shape}")
        x = self.relu(self.conv1(x))
        print(f"After conv1: {x.shape}")
        
        x = self.relu(self.conv2(x))
        print(f"After conv2: {x.shape}")
        
        x = self.relu(self.conv3(x))
        print(f"After conv3: {x.shape}")
        
        x = self.flatten(x)
        print(f"After flatten: {x.shape}")
        
        x = self.enc_linear(x)
        print(f"After encoder linear: {x.shape}")
        
        # Decoder
        x = self.dec_linear(x)
        print(f"After decoder linear: {x.shape}")
        
        # Reshape back to 3D
        batch_size = x.shape[0]
        x = x.view(batch_size, 64, -1)  # We'll calculate the proper reshape here
        print(f"After reshaping: {x.shape}")
        
        x = self.relu(self.deconv1(x))
        print(f"After deconv1: {x.shape}")
        
        x = self.relu(self.deconv2(x))
        print(f"After deconv2: {x.shape}")
        
        x = self.sigmoid(self.deconv3(x))
        print(f"Final output: {x.shape}")
        
        return x

# Test the model
model = Autoencoder3D()
test_input = torch.randn(1, 1, 108, 90, 108)
output = model(test_input)

Initial input shape: torch.Size([1, 1, 108, 90, 108])
After conv1: torch.Size([1, 16, 54, 45, 54])
After conv2: torch.Size([1, 32, 27, 23, 27])
After conv3: torch.Size([1, 64, 14, 12, 14])
After flatten: torch.Size([1, 150528])
Final shape for unflattening: (64, 2352)

Forward pass input shape: torch.Size([1, 1, 108, 90, 108])
After conv1: torch.Size([1, 16, 54, 45, 54])
After conv2: torch.Size([1, 32, 27, 23, 27])
After conv3: torch.Size([1, 64, 14, 12, 14])
After flatten: torch.Size([1, 150528])
After encoder linear: torch.Size([1, 256])
After decoder linear: torch.Size([1, 150528])
After reshaping: torch.Size([1, 64, 2352])


RuntimeError: Expected 4D (unbatched) or 5D (batched) input to conv_transpose3d, but got input of size: [1, 64, 2352]

In [17]:
# Initialize dataset and dataloader
dataset = MatFileDataset('../mat_files')
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Initialize model, loss function, and optimizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Autoencoder3D().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    total_loss = 0
    for batch in dataloader:
        # Move batch to device
        batch = batch.to(device)
        
        # Forward pass
        output = model(batch)
        loss = criterion(output, batch)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    # Print epoch statistics
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.6f}')

# Save the model
torch.save(model.state_dict(), 'autoencoder3d_model.pth')

RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x129024 and 150528x256)