In [2]:
!git clone https://github.com/moshesipper/vae-torch-celeba.git

Cloning into 'vae-torch-celeba'...
remote: Enumerating objects: 50, done.[K
remote: Counting objects: 100% (36/36), done.[K
remote: Compressing objects: 100% (24/24), done.[K
remote: Total 50 (delta 22), reused 26 (delta 12), pack-reused 14 (from 1)[K
Receiving objects: 100% (50/50), 28.72 MiB | 46.10 MiB/s, done.
Resolving deltas: 100% (23/23), done.


In [3]:
%cd vae-torch-celeba

/kaggle/working/vae-torch-celeba


In [4]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import os
import numpy as np
from torch.utils.data import Subset
from torchvision.utils import save_image

# Constants from your VAE model
IMAGE_SIZE = 150
LATENT_DIM = 128
image_dim = 3 * IMAGE_SIZE * IMAGE_SIZE

# Set seed for reproducibility
def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    
set_seed(42)

class CelebADataset(Dataset):
    def __init__(self, root_dir, partition_file, attr_file, transform=None):
        self.root_dir = root_dir
        self.img_dir = os.path.join(root_dir, 'img_align_celeba', 'img_align_celeba')
        self.partition_df = pd.read_csv(partition_file)
        self.attr_df = pd.read_csv(attr_file)
        self.transform = transform
        
    def __len__(self):
        return len(self.partition_df)
    
    def __getitem__(self, idx):
        img_name = os.path.join(self.img_dir, self.partition_df.iloc[idx, 0])
        image = Image.open(img_name)
        
        if self.transform:
            image = self.transform(image)
            
        # Flatten the image to match VAE requirements
        return image.view(-1)

def get_celeba_dataloaders(root_dir, batch_size=64, subset_fraction=0.1):
    # Define transforms according to VAE specifications
    transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE, antialias=True),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor()
    ])
    
    # Create dataset
    full_dataset = CelebADataset(
        root_dir=root_dir,
        partition_file=os.path.join(root_dir, 'list_eval_partition.csv'),
        attr_file=os.path.join(root_dir, 'list_attr_celeba.csv'),
        transform=transform
    )
    
    # Get indices for train/val/test
    partition_df = pd.read_csv(os.path.join(root_dir, 'list_eval_partition.csv'))
    train_indices = partition_df[partition_df['partition'] == 0].index.tolist()
    val_indices = partition_df[partition_df['partition'] == 1].index.tolist()
    test_indices = partition_df[partition_df['partition'] == 2].index.tolist()
    
    # Take subset if requested
    num_train_samples = int(len(train_indices) * subset_fraction)
    train_indices = train_indices[:num_train_samples]
    
    # Create subset datasets
    train_dataset = Subset(full_dataset, train_indices)
    val_dataset = Subset(full_dataset, val_indices)
    test_dataset = Subset(full_dataset, test_indices)
    
    print(f"Number of training samples: {len(train_dataset)}")
    print(f"Number of validation samples: {len(val_dataset)}")
    print(f"Number of test samples: {len(test_dataset)}")
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )
    
    return train_loader, val_loader, test_loader


In [5]:
pd.set_option('future.no_silent_downcasting', True)

In [6]:
def generate_reconstructions(model, test_loader, device, num_images=8):
    model.eval()
    model = model.to(device)  # Move model to GPU
    
    with torch.no_grad():
        # Get one batch
        for imgs in test_loader:
            pics = imgs.to(device)  # Move images to GPU
            # Reshape for the model
            pics = pics.view(-1, 3, IMAGE_SIZE, IMAGE_SIZE)
            break
            
        # Keep first num_images
        pics = pics[:num_images]
        orig = torch.clone(pics)
        
        # Generate reconstructions
        all_pics = [orig]
        current_pics = orig
        for _ in range(7):
            recon, _, _ = model(current_pics)
            recon = recon.view(-1, 3, IMAGE_SIZE, IMAGE_SIZE)
            all_pics.append(recon)
            current_pics = recon
            
        # Stack all images
        final_pics = torch.cat(all_pics, dim=0)
        
        # Move back to CPU for saving
        final_pics = final_pics.cpu()
        
        # Save the grid
        save_image(
            final_pics, 
            'reconstructions2.jpg', 
            nrow=num_images, 
            normalize=True
        )

# Usage
root_dir = '/kaggle/input/celeba-dataset'
train_loader, val_loader, test_loader = get_celeba_dataloaders(
    root_dir=root_dir,
    batch_size=64,
    subset_fraction=1.0
)

# Set up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load model
MODEL_FILE = 'vae_model_20.pth'
model = torch.load(MODEL_FILE, map_location=device, weights_only=False)

# Generate reconstructions
generate_reconstructions(model, test_loader, device)

Number of training samples: 162770
Number of validation samples: 19867
Number of test samples: 19962
Using device: cuda

CELEB_PATH ./data/ IMAGE_SIZE 150 LATENT_DIM 128 image_dim 67500


In [7]:
!ls

LICENSE    __pycache__	reconstructions2.jpg  utils.py	vae_model_20.pth
README.md  genpics.py	trainvae.py	      vae.py


In [8]:
!ls
%pwd

LICENSE    __pycache__	reconstructions2.jpg  utils.py	vae_model_20.pth
README.md  genpics.py	trainvae.py	      vae.py


'/kaggle/working/vae-torch-celeba'

In [9]:
attr = pd.read_csv("/kaggle/input/celeba-dataset/list_attr_celeba.csv")

In [10]:
import torch as th
import torch
from torch import nn
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
import torch as th
import nibabel as nib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from torchvision import transforms
from PIL import Image
import gc
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm import tqdm 

class FmriConvNet(nn.Module):
    def __init__(self):
        super(FmriConvNet, self).__init__()

        # Convolutional Block 1
        self.conv1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=3, stride=2, padding=1)
        self.relu1 = nn.ReLU()

        # Convolutional Block 2
        self.conv2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=3, stride=2, padding=1)
        self.relu2 = nn.ReLU()

        # Convolutional Block 3
        self.conv3 = nn.Conv3d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1)
        self.relu3 = nn.ReLU()

        # Convolutional Block 4
        self.conv4 = nn.Conv3d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1)
        self.relu4 = nn.ReLU()

        # Convolutional Block 5
        self.conv5 = nn.Conv3d(in_channels=64, out_channels=128, kernel_size=(5,5,3), stride=1, padding=0)
        self.relu5 = nn.ReLU()

        # Flatten and Fully Connected Layers
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(128 * 1 * 1 * 1, 1024)  # Reduced input size
        self.relu_fc1 = nn.ReLU()
        self.output = nn.Linear(1024, 1024)

        self.mu = nn.Linear(1024, 128)
        self.logvar = nn.Linear(1024, 128)

    def forward(self, x):
        x = self.relu1(self.conv1(x))  # Output shape: (batch_size, 8, 40, 40, 21)
        x = self.relu2(self.conv2(x))  # Output shape: (batch_size, 16, 20, 20, 11)
        x = self.relu3(self.conv3(x))  # Output shape: (batch_size, 32, 10, 10, 6)
        x = self.relu4(self.conv4(x))  # Output shape: (batch_size, 64, 5, 5, 3)

        x = self.relu5(self.conv5(x))  # Output shape: (batch_size, 128, 1, 1, 1)

        x = self.flatten(x)            # Output shape: (batch_size, 128)
        x = self.relu_fc1(self.fc1(x)) # Output shape: (batch_size, 1024)
        x = self.output(x)             # Output shape: (batch_size, 1024)

        mu = self.mu(x)
        logvar = self.logvar(x)
        return mu, logvar
brain2latent = FmriConvNet().to(device)
# br2l_path = "/kaggle/input/brain2latent/pytorch/default/1/brain2latent_final3.pth"
# brain2latent.load_state_dict(th.load(br2l_path)["model_state_dict"])

# with th.no_grad():
#     output = brain2latent(th.randn(5, 1, 80, 80, 41).float().to(device))
#     print(output[0].shape)

transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE, antialias=True),
        transforms.CenterCrop(IMAGE_SIZE),
        transforms.ToTensor()
    ])
def load_and_transform_image(image_path):
    """
    Loads an image, applies transformations, and returns the transformed image tensor
    in the shape (1, 3, 128, 128).
    """
    image = Image.open(image_path)
    transformed_image = transform(image)
    return transformed_image
header = "/kaggle/input/ds-sub2-download/ds001761-download/"
all_events=[]
for i in range(1,3):
    for j in range(1,9):
        for k in range(1,9):
            stub = header+f"sub-0{i}/ses-0{j}/func/sub-0{i}_ses-0{j}_task-faces_run-0{k}"
            event_path = stub +"_events.tsv"
            fmri_path = stub +"_bold.nii.gz"
            event_obj = pd.read_csv(event_path,delimiter='\t')
            event_obj["stim_file"]= "/kaggle/input/ds-sub2-download/ds001761-download/stimuli/" + event_obj["stim_file"]
            event_obj["fmri_map"] = (event_obj["onset"]/2).round()
            event_obj["subject_number"] = i
            event_obj["session_number"] = j
            event_obj["run_number"] = k
            all_events.append(event_obj)
            
event_set = pd.concat(all_events)
event_set = event_set[event_set["stim_file"] != "/kaggle/input/ds-sub2-download/ds001761-download/stimuli/placeholders/fixation.png"]
img2name = "/kaggle/input/images2celeba-txt/ImageNames2Celeba.txt"
img_df = pd.read_csv(img2name,delimiter='\t',names=["stim_file","image_id"])
img_df["stim_file"] = "/kaggle/input/ds-sub2-download/ds001761-download/stimuli/" + img_df["stim_file"]
event_set = pd.merge(event_set,img_df,on="stim_file",how="inner")
print(device)

cuda


In [11]:
event_set = pd.merge(event_set,attr,on="image_id",how="inner")

In [12]:
attribute_list =  ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 'Bags_Under_Eyes',
       'Bald', 'Bangs', 'Big_Lips', 'Big_Nose', 'Black_Hair', 'Blond_Hair',
       'Blurry', 'Brown_Hair', 'Bushy_Eyebrows', 'Chubby', 'Double_Chin',
       'Eyeglasses', 'Goatee', 'Gray_Hair', 'Heavy_Makeup', 'High_Cheekbones',
       'Male', 'Mouth_Slightly_Open', 'Mustache', 'Narrow_Eyes', 'No_Beard',
       'Oval_Face', 'Pale_Skin', 'Pointy_Nose', 'Receding_Hairline',
       'Rosy_Cheeks', 'Sideburns', 'Smiling', 'Straight_Hair', 'Wavy_Hair',
       'Wearing_Earrings', 'Wearing_Hat', 'Wearing_Lipstick',
       'Wearing_Necklace', 'Wearing_Necktie', 'Young']
print(len(attribute_list))

40


In [13]:
import pandas as pd
from collections import defaultdict

def get_class(attr_vals):
    hash=0
    pow=1
    for val in attr_vals:
        hash += val*pow
        pow*=2
    return hash


important_attr = ["Male"]
event_set["class"] = 0
class_batches = defaultdict(list)
for i, row in event_set.iterrows():
    attributes = row[important_attr].fillna(0).astype(int).values
    hashval = get_class(attributes)
    event_set.at[i, "class"] = hashval
    class_batches[hashval].append(i)
 #print(event_set.iloc[i]["class"],hashval)

In [14]:
class_counts = event_set["class"].value_counts()
subset = class_counts

In [15]:
print(min(subset))

6119


In [16]:
# from torch.data.utils import Sampler
# ## gonna make a dataset that returns the minibatch curated according to contrastive loss

# class contr_fmri(Sampler):
#     def __init__(self,classcounts,batch_indices,batch_size):
#         #self.event_set = event_set
#         #self.transform=transform
#         #self.attribute_list = attribute_list
#         self.classcounts = classcounts
#         self.class_hashes = classcounts.index.tolist()
#         self.batch_indices=batch_indices
#     def __len__(self):
#         return(min(self.classcounts))
#     def __iter__(self):
#         examples = defaultdict(list)
#         for(hash in self.class_hashes):
#             ## append indices of pandas dataframe
#             examples[hash].append(self.batch_indices[hash][idx:idx+batch_size])
            
            
        
    

In [17]:
!wget https://github.com/ANTsX/ANTsPy/releases/download/v0.5.4/antspyx-0.5.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

--2024-12-02 12:07:35--  https://github.com/ANTsX/ANTsPy/releases/download/v0.5.4/antspyx-0.5.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Resolving github.com (github.com)... 140.82.116.4
Connecting to github.com (github.com)|140.82.116.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/101671766/8d3c7d1d-41b3-451f-9791-5eb21f45a5ef?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20241202%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20241202T120736Z&X-Amz-Expires=300&X-Amz-Signature=4a78e72549ec6b41236bab90a1ef7f420572afdd50670d7aa769d4577360e2be&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Dantspyx-0.5.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl&response-content-type=application%2Foctet-stream [following]
--2024-12-02 12:07:36--  https://objects.githubusercontent.com/github-production-re

In [18]:
! pip install ./antspyx-0.5.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

Processing ./antspyx-0.5.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
Installing collected packages: antspyx
Successfully installed antspyx-0.5.4


In [19]:
# Initialize models and move to GPU
vae_gan = model
# mod = th.load(model_path, map_location='cuda')
# vae_gan.load_state_dict(mod["model_state_dict"])

def print_gpu_debug_info():
    print("CUDA available:", th.cuda.is_available())
    print("Current device:", th.cuda.current_device() if th.cuda.is_available() else "CPU")
    print("Device count:", th.cuda.device_count() if th.cuda.is_available() else 0)
    print("Device name:", th.cuda.get_device_name(0) if th.cuda.is_available() else "CPU")
    
    # Add memory usage info
    if th.cuda.is_available():
        print("\nGPU Memory Usage:")
        print(f"Allocated: {th.cuda.memory_allocated(0) / 1024**2:.2f} MB")
        print(f"Cached: {th.cuda.memory_reserved(0) / 1024**2:.2f} MB")

In [20]:
from torch.utils.data import Sampler
from collections import defaultdict
import numpy as np

class ContrastiveBatchSampler(Sampler):
    def __init__(self, labels):
        self.labels = labels
        self.label_to_indices = defaultdict(list)
        for idx, label in enumerate(self.labels):
            self.label_to_indices[label].append(idx)
        self.labels_set = list(self.label_to_indices.keys())
        self.num_samples = len(self.labels)
        self.batch_size = 4  # Each batch will have 4 samples (two pairs)

    def __iter__(self):
        # Compute the number of batches
        num_batches = len(self) 
        # Shuffle the class labels
        np.random.shuffle(self.labels_set)
        # Start generating batches
        for _ in range(num_batches):
            # Randomly select two classes
            if len(self.labels_set) < 2:
                # Re-shuffle if not enough classes
                np.random.shuffle(self.labels_set)
            classes = np.random.choice(self.labels_set, size=2, replace=False)
            batch_indices = []
            for cls in classes:
                # Randomly select two indices from this class
                indices = self.label_to_indices[cls]
                if len(indices) >= 2:
                    selected_indices = np.random.choice(indices, size=2, replace=False)
                else:
                    # If not enough samples, replicate indices
                    selected_indices = np.random.choice(indices, size=2, replace=True)
                batch_indices.extend(selected_indices)
            yield batch_indices

    def __len__(self):
        # Calculate the number of batches
        return self.num_samples // self.batch_size


In [21]:
from nilearn.datasets import load_mni152_template
import ants
# Load the MNI template (default resolution: 2mm)
mni_template = load_mni152_template()
template_path = mni_template.get_filename()
template_img_ants = ants.from_numpy(mni_template.get_fdata())

In [22]:
header_path = "/kaggle/input/ds-sub2-download/ds001761-download/"
from torch.utils.data import Dataset

class FmriDataSet(Dataset):
    def __init__(self, event_set, transform=None):
        self.event_set = event_set
        self.transform = transform
        self.attribute_list = attribute_list
        self.labels = event_set['class'].values  # Ensure 'class_label' is the correct column name

    def __len__(self):
        return len(self.event_set)

    def __getitem__(self, idx):
        row = self.event_set.iloc[idx]
        subject_num = row["subject_number"]
        session_num = row["session_number"]
        run_num = row["run_number"]
        img_path = row["stim_file"]
        fmri_map = row["fmri_map"].astype(int)
        #attribute_values = row[self.attribute_list].fillna(0).astype(int).values
        #attribute_set = th.tensor(attribute_values)#
        gender = th.tensor(row["class"])
        if self.transform:
            img = self.transform(img_path)
        fmri_path = f"/kaggle/input/ds-sub2-download/ds001761-download/sub-0{subject_num}/ses-0{session_num}/func/sub-0{subject_num}_ses-0{session_num}_task-faces_run-0{run_num}_bold.nii.gz"
        anat_path = f"/kaggle/input/ds-sub2-download/ds001761-download/sub-0{subject_num}/ses-0{session_num}/anat/sub-0{subject_num}_ses-0{session_num}_acq-01_T1w.nii.gz"
        run_fmri = nib.load(fmri_path).get_fdata()[:, :, :, fmri_map]
        #anat_img = nib.load(anat_path).get_fdata()
        ## takes too long for now.
        # Normalize the image
        # fmri_img_ants = ants.from_numpy(run_fmri)
        # anat_img_ants = ants.from_numpy(anat_img)
        # anat2mni_transform = ants.registration(
        #     fixed=template_img_ants,
        #     moving=anat_img_ants,
        #     type_of_transform='SyN',
        # )
        # fmri2mni = ants.apply_transforms(
        #     fixed=template_img_ants,
        #     moving=fmri_img_ants,
        #     transformlist=anat2mni_transform['fwdtransforms']
        # )
        #fmri_output = th.from_numpy(fmri2mni.numpy())
        fmri_output = th.from_numpy(run_fmri)
        fmri_output = fmri_output.unsqueeze(0)
        return img, run_fmri, gender


In [23]:
# header_path = "/kaggle/input/ds-sub2-download/ds001761-download/"
# from torch.utils.data import Dataset
# class FmriDataSet(Dataset):
#     def __init__(self,event_set,transform=None):
#         self.event_set = event_set
#         self.transform=transform
#         self.attribute_list = attribute_list
#     def __len__(self):
#         return(len(self.event_set))
#     def __getitem__(self,idx):
#         row = self.event_set.iloc[idx]
#         subject_num = row["subject_number"]
#         session_num = row["session_number"]
#         run_num = row["run_number"]
#         img_path = row["stim_file"]
#         fmri_map = row["fmri_map"].astype(int)
#         attribute_values = row[self.attribute_list].fillna(0).astype(int).values
#         attribute_set = th.tensor(attribute_values)
#         if self.transform:
#             img = self.transform(img_path)
#         fmri_path = f"/kaggle/input/ds-sub2-download/ds001761-download/sub-0{subject_num}/ses-0{session_num}/func/sub-0{subject_num}_ses-0{session_num}_task-faces_run-0{run_num}_bold.nii.gz"
#         anat_path = f"/kaggle/input/ds-sub2-download/ds001761-download/sub-0{subject_num}/ses-0{session_num}/anat/sub-0{subject_num}_ses-0{session_num}_acq-01_T1w.nii.gz"
#         run_fmri  = nib.load(fmri_path).get_fdata()[:,:,:,fmri_map]
#         anat_img  =nib.load(anat_path)
#         ## now im going to normalize said image
#         antspy_fmri = antspy.from_numpy(run_fmri)
        
#         anat2mni_transform = ants.registration(
#             fixed=template_img_ants,
#             moving=anat_img_ants,
#             type_of_transform='SyN',
#         )
#         fmri2mni = ants.apply_transforms(
#             fixed=template_img_ants,
#             moving=fmri_img_ants,
#             transformlist=anat2mni_transform['fwdtransforms']
#         )

#         fmri_output = th.from_numpy(fmri2mni)      
#         fmri_output = fmri_output.unsqueeze(0)
#         return(img,run_fmri,attribute_set)

In [24]:
train_events,test_events = train_test_split(event_set,test_size=0.2)
train_events,val_events = train_test_split(event_set,test_size=0.2)

In [25]:
train_fmri_dataset = FmriDataSet(train_events,load_and_transform_image)
test_fmri_dataset =  FmriDataSet(test_events,load_and_transform_image)
val_fmri_dataset = FmriDataSet(val_events,load_and_transform_image)

In [26]:
batch_sampler = ContrastiveBatchSampler(train_fmri_dataset.labels)

In [27]:
train_dataloader = DataLoader(train_fmri_dataset, batch_sampler=batch_sampler)
test_dataloader = DataLoader(test_fmri_dataset, batch_size=4)
val_dataloader = DataLoader(test_fmri_dataset, batch_size=8)

In [28]:
# progress_bar = tqdm(enumerate(test_dataloader), desc=f"Epoch {epoch+1}", total=len(test_dataloader), leave=False)

for i, data in enumerate(train_dataloader):
    img_batch, fmri_data_batch,attribute_set = data
    print(img_batch.shape, fmri_data_batch.shape)
    print(attribute_set,attribute_set.shape)
    break

torch.Size([4, 3, 150, 150]) torch.Size([4, 80, 80, 41])
tensor([-1, -1,  1,  1]) torch.Size([4])


In [29]:
!mkdir /kaggle/working/models/

In [30]:
loss_criterion= nn.CrossEntropyLoss()
def pair_contrastive_loss(e1,e2,c1,c2):
    # fmri_embeddings will be a (2,fmri_shape) vector , the classes vector correpond to both
    loss = (c1 == c2)*(loss_criterion(e1,e2)) + (c1!=c2)*(max(0,loss_criterion(e1,e2)))
    return loss

In [None]:
num_epochs = 3
optimizer = th.optim.Adam(brain2latent.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    construction_loss_value = 0
    epoch_loss = 0
    progress_bar = tqdm(enumerate(train_dataloader), desc=f"Epoch {epoch+1}", total=len(train_dataloader), leave=False)

    for i, data in progress_bar:
        img_batch, fmri_data_batch,classes = data
        img_batch = img_batch.view(-1, 3, IMAGE_SIZE, IMAGE_SIZE)  # Reshape if needed

        # Forward pass
        with th.no_grad():
            encoded_mu, encoded_logvar = vae_gan.encode(img_batch.to(device, non_blocking=True))

        fmri_data_batch = fmri_data_batch.reshape(4,1,80,80,41)
        #print(fmri_data_batch.shape)
        latent_mu, latent_logvar = brain2latent(fmri_data_batch.float().to(device, non_blocking=True))
        
        # Compute MSE loss
        construction_loss = F.mse_loss(encoded_mu, latent_mu)
        construction_loss += F.mse_loss(encoded_logvar, latent_logvar)
        construction_loss_value = construction_loss.item()
        for i in range(4):
            #print(encoded_mu[i].shape)
            for j in range(i+1,4):
                construction_loss += pair_contrastive_loss(encoded_mu[i],encoded_mu[j],classes[i],classes[j] ) 

        
        # Backward pass
        optimizer.zero_grad()
        construction_loss.backward()
        optimizer.step()

        # Update loss
        epoch_loss += construction_loss.item()
        
        if (i % 100 == 0) and i != 0:
            print_gpu_debug_info()
            
        progress_bar.set_description(f"Epoch {epoch+1}, TrainLoss: {construction_loss_value:.4f}")
        th.cuda.empty_cache()
        gc.collect()

    save_path = f'/kaggle/working/models/brain2latent_epoch{epochs+1}.pth'
    th.save({
        'epoch': num_epochs,
        'model_state_dict': brain2latent.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': epoch_loss,
    }, save_path)
    
    
    print(f"Epoch {epoch+1} completed, Average Loss: {epoch_loss / len(train_dataloader):.4f}")

Epoch 1, TrainLoss: 1.1037:  38%|███▊      | 922/2457 [1:09:32<1:54:17,  4.47s/it]

In [None]:
import matplotlib.pyplot as plt
from tqdm import tqdm

# Assuming IMAGE_SIZE is defined elsewhere
IMAGE_SIZE = 150

brain2latent.eval()
vae_gan.eval()
avg_test_loss = 0
epoch = 0
count = 0

progress_bar = tqdm(enumerate(test_dataloader), desc=f"Epoch {epoch+1}", total=len(test_dataloader), leave=False)

for i, data in progress_bar:
    img_batch, fmri_data_batch = data
    test_loss = 0

    # Forward pass
    with th.no_grad():
        # Encode fMRI data and images
        encoded_mu, encoded_logvar = vae_gan.encode(img_batch.to(device, non_blocking=True))
        latent_mu, latent_logvar = brain2latent(fmri_data_batch.float().to(device, non_blocking=True))
        print(encoded_mu.shape,encoded_logvar.shape)
        # Compute MSE loss
        test_loss = F.mse_loss(encoded_mu, latent_mu)
        test_loss += F.mse_loss(encoded_logvar, latent_logvar)
        test_lossvalue = test_loss.item()
        print(f"Test Loss: {test_lossvalue:.4f}")
        
        avg_test_loss += test_loss

        # Decode generated images from latent space
        img = vae_gan.decode(vae_gan.reparameterize(latent_mu, latent_logvar))
        
        # Generate the reconstruction of the original image batch
        vaegan_imgs = vae_gan.decode(vae_gan.reparameterize(encoded_mu, encoded_logvar))

        # Convert tensors to CPU numpy arrays for visualization
        img = img.cpu().detach().numpy()
        print(vaegan_imgs)
        vaegan_imgs = vaegan_imgs.cpu().detach().numpy()
        img_batch = img_batch.cpu().detach().numpy()

        # Plot original, reconstructed, and generated images
        fig, axes = plt.subplots(4, 3, figsize=(12, 10))  # 4 rows for 4 images, 3 columns
        for j in range(4):  # Plot first 4 images in the batch
            # Original Image
            axes[j, 0].imshow(np.clip(img_batch[j].transpose(1, 2, 0), 0, 1))
            axes[j, 0].axis("off")
            axes[j, 0].set_title("Original Image")

            # Reconstructed Image
            axes[j, 1].imshow(np.clip(vaegan_imgs[j].transpose(1, 2, 0), 0, 1))
            axes[j, 1].axis("off")
            axes[j, 1].set_title("Reconstructed Image")

            # Generated Image
            axes[j, 2].imshow(np.clip(img[j].transpose(1, 2, 0), 0, 1))
            axes[j, 2].axis("off")
            axes[j, 2].set_title("Generated Image")

        plt.tight_layout()
        plt.show()

    break  # Process only the first batch for visualization
    count += 1
    if count == 10:  # Limit to 10 iterations
        break

    progress_bar.set_description(f"Epoch {epoch+1}, TestLoss: {test_lossvalue:.4f}")

# Compute the average test loss
avg_test_loss = avg_test_loss / (i + 1)
print(f"Average Test Loss: {avg_test_loss:.4f}")


In [None]:
!ls /kaggle/working/vae-torch-celeba/ds001761-download/

In [None]:
# Save at the end of training
model_save_path = 'brain2latent_final.pth'
th.save({
    'epoch': num_epochs,
    'model_state_dict': brain2latent.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': epoch_loss,
}, model_save_path)



# To load the model later:
def load_model(path):
    checkpoint = th.load(path)
    brain2latent.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']

In [None]:
from IPython.display import FileLink
FileLink('./brain2latent_final.pth')

In [None]:
import shutil

shutil.make_archive('zipped_file_name', 'zip', '/vae-torch-celeba/brain2latent_final.pth')
from IPython.display import FileLink
FileLink(r'zipped_file_name.zip')

In [None]:
# Save at the end of training
model_save_path = 'brain2latent_final3.pth'
th.save({
    'model_state_dict': brain2latent.state_dict(),
}, model_save_path)