In [2]:
import torch
import torch.nn as nn 
import numpy as np
import matplotlib.pyplot as plt 
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

In [3]:
# --- Dataset split -----------------------------------------------------------
# Sample files are addressed as dataset/UTS01/{i}.npy for i in [0, DATASET_LEN).
DATASET_LEN = 28088
TRAIN_TEST_RATIO = 0.9
# NOTE: despite the name, this is the size of the TRAINING split; the test
# split is the remaining DATASET_LEN - TRAIN_TEST_LEN samples.
TRAIN_TEST_LEN = int(DATASET_LEN * TRAIN_TEST_RATIO)

# --- Model / hardware --------------------------------------------------------
# Input and output width of the model's linear layer.
NUMBER_VOXELS = 81126
# Use Apple's Metal backend when it exists, otherwise fall back to the CPU so
# the notebook still runs on machines without MPS support.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"

# --- Optimisation ------------------------------------------------------------
LEARNING_RATE = 1e-6

In [4]:
class test_fmri_dataset(Dataset):
    """Held-out split: the samples that follow the first ``TRAIN_TEST_LEN`` files.

    Item ``i`` is read from ``dataset/UTS01/{i + TRAIN_TEST_LEN}.npy``.
    """

    def __init__(self):
        super().__init__()

    def __len__(self):
        """Return the number of held-out samples."""
        return DATASET_LEN - TRAIN_TEST_LEN

    def __getitem__(self, index):
        """Load one held-out sample.

        Args:
            index: Position inside the held-out split. Negative values count
                from the end, as they do for a list.

        Returns:
            torch.Tensor: The stored array as ``float32`` with non-finite
            values replaced by ``np.nan_to_num`` and a new leading axis of
            size 1. (Presumably the stored array is a flat voxel vector —
            TODO confirm against the files on disk.)

        Raises:
            IndexError: If ``index`` falls outside the split. Without this
                check a negative index would silently read a training file.
        """
        size = len(self)
        position = index + size if index < 0 else index
        if not 0 <= position < size:
            raise IndexError(
                f"index {index} is out of range for a dataset of length {size}"
            )

        data = np.load(f"dataset/UTS01/{position + TRAIN_TEST_LEN}.npy")
        data = np.nan_to_num(data)
        return torch.tensor(data, dtype=torch.float32).unsqueeze(0)
    
class train_fmri_dataset(Dataset):
    """Training split: the first ``TRAIN_TEST_LEN`` sample files.

    Item ``i`` is read from ``dataset/UTS01/{i}.npy``.
    """

    def __init__(self):
        super().__init__()

    def __len__(self):
        """Return the number of training samples."""
        return TRAIN_TEST_LEN

    def __getitem__(self, index):
        """Load one training sample.

        Args:
            index: Position inside the training split. Negative values count
                from the end, as they do for a list.

        Returns:
            torch.Tensor: The stored array as ``float32`` with non-finite
            values replaced by ``np.nan_to_num`` and a new leading axis of
            size 1. (Presumably the stored array is a flat voxel vector —
            TODO confirm against the files on disk.)

        Raises:
            IndexError: If ``index`` falls outside the split. Without this
                check an index >= ``TRAIN_TEST_LEN`` would silently return a
                held-out test sample.
        """
        size = len(self)
        position = index + size if index < 0 else index
        if not 0 <= position < size:
            raise IndexError(
                f"index {index} is out of range for a dataset of length {size}"
            )

        data = np.load(f"dataset/UTS01/{position}.npy")
        data = np.nan_to_num(data)
        return torch.tensor(data, dtype=torch.float32).unsqueeze(0)

In [5]:
# Both loaders yield shuffled mini-batches of 8 samples. The test loader is
# shuffled as well, so taking the first batch of a fresh iterator gives a
# random held-out batch.
train_dataloader = DataLoader(
    dataset=train_fmri_dataset(),
    batch_size=8,
    shuffle=True,
)
test_dataloader = DataLoader(
    dataset=test_fmri_dataset(),
    batch_size=8,
    shuffle=True,
)

In [6]:
def train_loop(model, optimiser, epochs):
    """Train ``model`` on ``train_dataloader`` and periodically probe test loss.

    Args:
        model: Module whose forward pass returns ``(loss, prediction)``.
        optimiser: Optimiser bound to ``model``'s parameters.
        epochs: Number of full passes over ``train_dataloader``.

    Returns:
        list[float]: Loss on one random test batch, recorded every 10
        optimiser steps (starting with step 0).
    """
    losses = []
    steps = 0

    # Send every batch to wherever the model's weights live, so training and
    # test batches can never end up on a different device than the model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device(DEVICE)

    for _epoch in range(epochs):
        model.train(True)
        for batch in train_dataloader:
            inputs = batch.to(device)

            optimiser.zero_grad()
            loss, _pred = model(inputs)

            print(loss)
            loss.backward()
            optimiser.step()

            if steps % 10 == 0:
                # Evaluation only: no autograd graph, and keep the results in
                # their own names so the training batch/loss are not clobbered.
                model.eval()
                with torch.no_grad():
                    # A fresh iterator over the shuffled test loader yields a
                    # random held-out batch.
                    test_inputs = next(iter(test_dataloader)).to(device)
                    test_loss, _test_pred = model(test_inputs)
                model.train(True)

                print("test loss ---------------------------------")
                print(test_loss)
                losses.append(test_loss.item())

            steps += 1
    return losses


In [7]:
class MAE_Model(nn.Module):
    """Single linear layer trained to reconstruct its input under an L1 loss.

    Args:
        num_voxels: Width of both the input and the output. When omitted, the
            notebook-level ``NUMBER_VOXELS`` constant is used.
    """

    def __init__(self, num_voxels=None):
        super().__init__()
        if num_voxels is None:
            num_voxels = NUMBER_VOXELS
        # A square layer holds num_voxels**2 weights plus num_voxels biases.
        self.layer1 = nn.Linear(num_voxels, num_voxels)

    def forward(self, x):
        """Reconstruct ``x`` and score the reconstruction.

        Args:
            x: Tensor whose last dimension equals the layer width.

        Returns:
            tuple: ``(loss, output)`` where ``output`` is the layer's
            prediction and ``loss`` is the scalar mean absolute error between
            ``output`` and ``x``.
        """
        output = self.layer1(x)
        # Use the functional form: calling ``nn.L1Loss(...)`` with tensors only
        # constructs a loss module. The prediction must be compared with the
        # input, otherwise the loss has no dependence on the parameters.
        loss = nn.functional.l1_loss(output, x)

        return loss, output

In [8]:
# Place the weights on the target device before building the optimiser, so the
# model, its optimiser state and the batches moved to DEVICE all share a device.
model = MAE_Model().to(DEVICE)
optimiser = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [14]:
# Add up the element count of every parameter tensor and print it with
# thousands separators.
print(f"model total paramters: {sum(map(torch.numel, model.parameters())):,}")

model total paramters: 6,581,509,002


In [None]:
# Run a single epoch and keep the list that train_loop returns.
losses = train_loop(model=model, optimiser=optimiser, epochs=1)