In [1]:
import numpy as np
import pandas as pd


import torch
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import Dataset, DataLoader

from utils import load_test_data, visualize_test_region, generate_training_samples
from dataset import STDataset
from baseline import RandomRegionBaseline, TissueSpecificRandomRegionBaseline
from evaluate import Evaluator

from parameter import create_args
from models.common import get_linear_scheduler
from models.vae_gaussian import GaussianVAE
from models.vae_flow import FlowVAE


In [2]:
# Number of training samples drawn from each slice; tune here rather than inline.
# NOTE(review): running this cell prints "Seeding all randomness with seed=2024",
# so generation appears to reseed global RNGs — confirm in utils.generate_training_samples.
NUM_SAMPLES_PER_SLICE = 3
training_samples = generate_training_samples(
    num_samples_per_slice=NUM_SAMPLES_PER_SLICE,
)

Seeding all randomness with seed=2024


In [3]:
# Wrap the generated samples in the project's dataset class. It is passed to a
# torch DataLoader in the training cell, whose loop reads the 'positions',
# 'expressions' and 'metadata' keys from each batch.
dataset = STDataset(training_samples)

In [4]:
# Build the run configuration. Later cells read model, device, lr, weight_decay,
# sched_start_epoch, sched_end_epoch, end_lr and max_grad_norm from it.
# NOTE(review): if create_args uses argparse, confirm it ignores the Jupyter
# kernel's own argv (e.g. via parse_args([]) or parse_known_args).
args = create_args()

In [5]:
# Instantiate the VAE variant selected by args.model and move it to args.device.
# The explicit else branch matters in a notebook: without it, an unrecognised
# value leaves `model` undefined (NameError in a later cell) or, on a reused
# kernel, silently keeps a model left over from an earlier run.
if args.model == 'gaussian':
    model = GaussianVAE(args).to(args.device)
elif args.model == 'flow':
    model = FlowVAE(args).to(args.device)
else:
    raise ValueError(
        f"Unknown args.model={args.model!r}; expected 'gaussian' or 'flow'."
    )

In [6]:
# Display the module tree (layers and their sizes) as a quick architecture sanity check.
model

GaussianVAE(
  (encoder): PointNetEncoder(
    (conv1): Conv1d(376, 128, kernel_size=(1,), stride=(1,))
    (conv2): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
    (conv3): Conv1d(128, 256, kernel_size=(1,), stride=(1,))
    (conv4): Conv1d(256, 512, kernel_size=(1,), stride=(1,))
    (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (bn4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc1_m): Linear(in_features=512, out_features=256, bias=True)
    (fc2_m): Linear(in_features=256, out_features=128, bias=True)
    (fc3_m): Linear(in_features=128, out_features=256, bias=True)
    (fc_bn1_m): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (fc_bn2_m): BatchNorm1d(128, eps=1e-05, moment

In [7]:
# Optimizer: Adam over every model parameter, with weight decay taken from args.
adam_config = {'lr': args.lr, 'weight_decay': args.weight_decay}
optimizer = torch.optim.Adam(model.parameters(), **adam_config)

# Scheduler: moves the learning rate from args.lr to args.end_lr between the
# configured start/end bounds (see models.common.get_linear_scheduler).
# NOTE(review): the training cell calls scheduler.step() once per batch, so these
# "epoch" bounds are effectively counted in optimizer steps — confirm intended.
schedule_config = {
    'start_epoch': args.sched_start_epoch,
    'end_epoch': args.sched_end_epoch,
    'start_lr': args.lr,
    'end_lr': args.end_lr,
}
scheduler = get_linear_scheduler(optimizer, **schedule_config)

In [8]:
# Define a DataLoader to handle batching
dataloader = DataLoader(dataset, batch_size=10, shuffle=True, num_workers=4)

# Nothing inside the loop switches the model to eval mode, so enter training
# mode once here instead of on every iteration.
model.train()

# Per-batch diagnostics. `loss` below still holds only the final batch's value.
batch_losses = []
grad_norms = []

# One pass over the dataset.
for batch_idx, batch in enumerate(dataloader):
    positions = batch['positions']
    expressions = batch['expressions']
    # Unused for now; kept for logging, tracking, or conditioning.
    metadata = batch['metadata']

    if batch_idx == 0:
        # Shape sanity check on the first batch only, rather than one line per batch.
        print(positions.shape, expressions.shape)

    # Join positions and expressions along dim 2 to form the model input.
    x = torch.cat((positions, expressions), dim=2).to(args.device)

    optimizer.zero_grad()
    loss = model.get_loss(x)

    # Backward, clip gradients to args.max_grad_norm, then update weights and LR.
    loss.backward()
    orig_grad_norm = clip_grad_norm_(model.parameters(), args.max_grad_norm)
    optimizer.step()
    scheduler.step()

    # Record scalars (not tensors) so the per-batch autograd graphs are not kept alive.
    batch_losses.append(loss.item())
    grad_norms.append(float(orig_grad_norm))

torch.Size([3, 50, 2]) torch.Size([3, 50, 374])


In [9]:
# Loss of the LAST batch processed by the training loop (`loss` is reassigned on
# every iteration), not an average over the epoch.
loss

tensor(60033.6172, device='cuda:0', grad_fn=<AddBackward0>)