In [1]:
import pandas as pd
import os
from project.dataset import Dataset, VALDODataset
from project.preprocessing import z_score_normalization, min_max_normalization, NiftiToTensorTransform
# from project.preprocessing import z_score_normalization, min_max_normalization
from project.training import split_train_val_datasets
from project.utils import collate_fn, plot_all_slices, plot_all_slices_from_array
from torch.utils.data import DataLoader
import torch
from project.model import VisionTransformer
from torch.optim import Adam
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from project.model.rpn_to_gcvit import RPN_to_GCVIT


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
D_model = 256
N_classes = 2
Img_size = (16,16)
Patch_size = (16, 16)
N_channels = 1
N_heads = 8
N_layers = 3
batch_size = 1
epochs = 3
Alpha = 1e-5

In [3]:
ds = Dataset()

In [4]:
masks = ds.load_cmb_masks()
cases = ds.load_skullstripped_mri()

In [5]:
transform = NiftiToTensorTransform(target_shape=(512,512))

In [6]:
dataset = VALDODataset(
    cases=cases, 
    masks=masks, 
    transform=transform,
    normalization=z_score_normalization
)

In [7]:
has_cmb = [1 if count > 0 else 0 for count in dataset.cmb_counts]

df_dataset = pd.DataFrame({
    'MRI Scans': dataset.cases,
    'Segmented Masks': dataset.masks,
    'CMB Count': dataset.cmb_counts,
    'Has CMB': has_cmb
})

In [8]:
train_dataset, val_dataset = split_train_val_datasets(
    df=df_dataset, 
    transform=transform
)

In [9]:
train_loader = DataLoader(
    train_dataset, 
    shuffle=True, 
    batch_size=batch_size,
    collate_fn=collate_fn
)

test_loader = DataLoader(
    val_dataset, 
    shuffle=False, 
    batch_size=batch_size,
    collate_fn=collate_fn
)

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [11]:
transformer = VisionTransformer(
    D_model=D_model,
    N_classes=N_classes,
    Img_size=Img_size,
    Patch_size=Patch_size,
    N_channels=N_channels,
    N_heads=N_heads,
    N_layers=N_layers,
    device=device
)

In [12]:
transformer = transformer.to(device)
connector = RPN_to_GCVIT()

In [13]:
optimizer = Adam(transformer.parameters(), lr=Alpha)
criterion = nn.CrossEntropyLoss()

In [14]:
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
loss_history = []

for epoch in range(epochs):
    training_loss = 0.0
    progress_bar = tqdm(
        enumerate(train_loader, 0),
        total=len(train_loader),
        desc=f'Epoch {epoch + 1}/{epochs} loss: {training_loss / len(train_loader):.3f}'
    )

    epoch_loss_history = []
    
    for i, data in progress_bar:

        cropped_images = connector.get_cropped_locations(img=data[0], x_min=160, y_min=324, x_max=176, y_max=340).float().to(device)
        cropped_labels = connector.get_cropped_locations(img=data[1], x_min=160, y_min=324, x_max=176, y_max=340).float().to(device)

        num_slices = cropped_images.size(2)
        print(num_slices)
        for j in range(num_slices):

            print('-----------------------')
            print(f'Learning case {i} slice {j}')
            
            optimizer.zero_grad()
            outputs = transformer(cropped_images, cropped_labels, current_slice=j)
            # outputs = F.interpolate(
            #     outputs,
            #     size=img_size,
            #     mode='bilinear',
            #     align_corners=False
            # )
            
            if cropped_labels.max() >= N_classes:
                cropped_labels = torch.clamp(cropped_labels, 0, N_classes-1)

            if epoch == 2:
                for i in range(outputs.shape[0]):  # Assuming outputs[0] is the batch dimension
                    img = outputs[0][i].cpu().detach().numpy()  # Move to CPU and convert to NumPy
                    plt.imshow(img.squeeze(), cmap='gray')  # Squeeze to remove single-dimensional entries
                    plt.title(f'Slice {i}')
                    plt.show()

            print(outputs.shape)
            batch, channel, num_slices, height, width = cropped_images.shape
            cropped_labels = cropped_labels.view(batch_size, height * num_slices, width)
            # cropped_labels = cropped_labels.long()
            loss = criterion(outputs, cropped_labels.long())
            epoch_loss_history.append(loss)
            loss.backward()
            optimizer.step()
            training_loss += loss.item()
        
    loss_history.append(epoch_loss_history)
    print(f'Epoch {epoch + 1}/{epochs} loss: {training_loss:.3f}')

Epoch 1/3 loss: 0.000:   0%|          | 0/57 [00:00<?, ?it/s]

49
-----------------------
Learning case 0 slice 0
Shape before patch: torch.Size([1, 1, 49, 16, 16])
before torch.Size([1, 1, 49, 16, 16])
after view torch.Size([1, 49, 256])
after project torch.Size([1, 49, 256])


Epoch 1/3 loss: 0.000:   0%|          | 0/57 [00:02<?, ?it/s]

Shape after patch: torch.Size([1, 49, 256])
Shape after positional: torch.Size([1, 50, 256])
Shape after transformer: torch.Size([1, 50, 256])
torch.Size([])





RuntimeError: shape '[1, 1, 256]' is invalid for input of size 1

In [16]:
import winsound

# Frequency (Hz) and duration (ms)
frequency = 1000  # Set frequency to 1000 Hz
duration = 500    # Set duration to 500 ms

# Play the sound
winsound.Beep(frequency, duration)