In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import zarr
from torchinfo import summary

import sys
# sys.path.insert(0, '/home/paperspace/Al/ClimateHack-2024/code/')


from models.multimodal import MultimodalModel
from models.compressor import CompressorModel
from models.attention import AttentionModel
from dataloader import ChDataModule
from train import train_epoch, eval_epoch, train_loop

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [2]:
argparser = argparse.ArgumentParser()
argparser.add_argument("--model", type=str, default="multimodal")
argparser.add_argument("--name", type=str, default="0")
argparser.add_argument("--pretrain", type=bool, default=False)
argparser.add_argument("--use_hrv", type=bool, default=True)
argparser.add_argument("--use_weather", type=bool, default=False)
argparser.add_argument("--use_metadata", type=bool, default=True)
argparser.add_argument("--use_pv", type=bool, default=True)
argparser.add_argument("--epochs", type=int, default=80)
argparser.add_argument("--add_epochs", type=int, default=0)
argparser.add_argument("--batch_size", type=int, default=32)
argparser.add_argument("--lr", type=float, default=1e-3)
argparser.add_argument("--weight_decay", type=float, default=0.00)
argparser.add_argument("--dropout", type=float, default=0.0)
argparser.add_argument("--batchnorm", type=bool, default=True)
argparser.add_argument("--checkpoint", type=str, default=None)
argparser.add_argument("--data_dir", type=str, default="data")
argparser.add_argument("--dataloader_cfg", type=dict, default={"num_workers": 8, "batch_size": 8, "pin_memory": True, "persistent_workers": True})
argparser.add_argument("--datamodule_cfg", type=dict, default={"val_split": 0.1, "cache_dir": "data/cache/"})
argparser.add_argument("--freeze", type=bool, default=False)
argparser.add_argument("--train", type=bool, default=True)
argparser.add_argument("-f", type=str, default="")

args = argparser.parse_args()

In [3]:
epochs = args.epochs

datamodule = ChDataModule(args.datamodule_cfg, args.dataloader_cfg)
datamodule.setup('fit')
train_loader = datamodule.train_dataloader()
val_loader = datamodule.val_dataloader()
criterion = nn.L1Loss()

Fetching 25 files:   0%|          | 0/25 [00:00<?, ?it/s]

In [4]:
if not args.pretrain:
    datamodule.toggle_train_hrv()
    datamodule.toggle_val_hrv()

use_hrv set to True
use_hrv set to True


In [5]:
if args.model == "multimodal":
    model = MultimodalModel(args)
    summary(model)

elif args.model == "compressor":
    model = CompressorModel(args)
    summary(model)

elif args.model == "attention":
    model = AttentionModel(args)
    summary(model)

else:
    raise ValueError("Invalid model")

model.to(device)

MultimodalModel(
  (conv1): Conv2d(12, 24, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(24, 48, kernel_size=(3, 3), stride=(1, 1))
  (conv3): Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1))
  (conv4): Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1))
  (conv5): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (relu): ReLU()
  (dropout): Dropout(p=0.0, inplace=False)
  (sigmoid): Sigmoid()
  (batchnorm): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (linear0): Linear(in_features=19, out_features=256, bias=True)
  (linear1): Linear(in_features=256, out_features=512, bias=True)
  (linear2): Linear(in_features=512, out_features=1024, bias=True)
  (linear3): Linear(in_features=1024, out_features=512, bias=True)
  (linear4): Linear(in_features=512, out_features=48, bias=True)
  (linear5): Linear(in_features=384,

In [6]:
if args.checkpoint is not None:
    model.load_state_dict(torch.load(args.checkpoint))

if args.freeze:
    model.freeze_pretrain()

if args.train:
    train_loop(model, args, train_loader, val_loader)

eval_epoch(model, args, criterion, val_loader)


Epoch 1
Train Loss: 0.19668770325027804


NameError: name 'np' is not defined