# Training a VAE on RGB (I3D) Feature Vectors

### IMPORTS

In [None]:
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from utils.loaders import FeaturesDataset

from models import FC_VAE
from train_vae import train


### SETUP

In [None]:
# --- Optimisation hyperparameters ---
BATCH_SIZE = 32
EPOCHS = 50
LR = 0.001
# NOTE(review): MOMENTUM and WEIGHT_DECAY are defined here but not passed to the
# Adam optimizer built in the training cell — confirm whether they should be wired in.
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4
# StepLR schedule: the learning rate is multiplied by GAMMA every STEP_SIZE scheduler steps.
STEP_SIZE = 10
GAMMA = 0.1

# --- Device selection ---
# Apple-silicon MPS takes priority when available; otherwise CUDA, otherwise CPU.
# NOTE: MPS yields a torch.device object, while CUDA/CPU yield plain strings.
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    print("------ USING APPLE SILICON GPU ------")
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Path (without extension) of the pre-extracted features consumed by FeaturesDataset.
features_file = "saved_features/saved_feat_I3D_25_dense_D1"

### TRAINING

In [None]:
from pathlib import Path

# Training data: the 'train' split of the saved features, shuffled, with the last
# incomplete batch dropped so every batch has exactly BATCH_SIZE samples.
train_dataset = FeaturesDataset(features_file, 'train')
train_loader_rgb = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

# Fully-connected VAE built with dim_input=1024 and nz=64
# (nz presumably the latent size — confirm against models.FC_VAE).
model = FC_VAE(dim_input=1024, nz=64)
model.to(DEVICE)

# Create Optimizer & Scheduler objects
# NOTE(review): WEIGHT_DECAY from the setup cell is not passed to Adam — confirm this is intended.
optimizer = Adam(model.parameters(), lr=LR, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE, gamma=GAMMA)

# NOTE(review): the same training loader is passed in both loader positions. If train()
# uses the second one for validation, that evaluation runs on training data — supply a
# held-out loader if such a split exists.
train(model, optimizer, EPOCHS, DEVICE, train_loader_rgb, train_loader_rgb, BATCH_SIZE, scheduler)

# torch.save does not create missing parent directories; without this, a fresh checkout
# would crash here after the full training run and the trained weights would be lost.
save_dir = Path('./saved_models/VAE_RGB')
save_dir.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), save_dir / f'final_VAE_RGB_epoch_{EPOCHS}.pth')
