# IMPORTS

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install git+https:github.com/Roodster/dsait4125-cv.git

In [None]:
import os
from types import SimpleNamespace
from datetime import datetime
import gdown


# Create dataset objects using the DspritesDataset class
from src.args import Args
from src.registry import setup
from src.dataset import DspritesDataset, get_dataloaders_2element, BinarySyntheticDataset, get_dataloaders
from src.experiment import Experiment
from src.common.utils import set_seed


# CONSTANTS

In [None]:
""" BASE PATH TO DRIVE TO STORE OUTPUTS"""
ROOT = "/content/drive/MyDrive/"

# IO

In [1]:
""" LOAD IN DATA """

def load_datasets():
    # Define the file IDs from your Google Drive shareable links
    file_ids = {
        'train.npz': '',
        'test.npz': '1hV-6Q29ixhqqrmCro8WN9WI6NxfSTEE5' 
    }

    # Create a directory to store the downloaded files
    os.makedirs('/data/2d', exist_ok=True)
    
    # Download each file from Google Drive
    for file_name, file_id in file_ids.items():
        url = f'https://drive.google.com/uc?id={file_id}'
        output = f'/content/data/2d/{file_name}.npz'
        gdown.download(url, output, quiet=False)
    
    
    train_data = DspritesDataset('/data/2d/train.npz')
    test_data = DspritesDataset('/data/2d/test.npz')
    
    return train_data, test_data

SyntaxError: expression expected after dictionary key and ':' (<ipython-input-1-52ee181a6b6a>, line 7)

In [None]:
def get_data_loaders(train_data, test_data, batch_size=32, shuffle=True, num_workers=2):
    from src.dataset import get_dataloaders_2element
    
    train_loader, test_loader = get_dataloaders_2element(
        train_data=train_data,
        test_data=test_data,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers
    )
    
    return train_loader, test_loader

In [None]:
args = SimpleNamespace(
    # Metadata
    exp_name="dev",
    model_name="maga",
    # Experiment
    seed=1,
    # File handling
    log_dir=""

    # Model parameters
    in_channels=1,
    img_size=64,
    latent_dim=10,
    
    # Dataset parameters
    train_ratio=0.7,
    test_ratio=0.3,
    val_ratio=1 - train_ratio - test_ratio,
    
    # Training parameters
    device="gpu",
    batch_size=32,
    learning_rate=0.001,
    n_epochs=3,
    
    # MAGA specific parameters
    beta_kl=1,
    beta_recon=1,
    
    # Evaluation parameters
    eval_save_model_interval=1,
    eval_interval=1,
    eval_sample_rate=1
)

current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
args.log_dir = f"{ROOT}/outputs/run_{args.exp_name}_{args.model_name}/seed_{args.seed}_{current_time}",

In [None]:


train_data, test_data = load_datasets()
train_loader, test_loader = get_data_loaders(train_data, test_data, batch_size=args.batch_size, shuffle=True, num_workers=2)
registry = setup(args.model_name)
set_seed(args.seed)

# Initialize experiment
experiment = Experiment(registry=registry, args=args)

# Run experiment
experiment.run(train_loader=train_loader, test_loader=test_loader)
    