In [1]:
from datetime import datetime, time, timedelta
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import xarray as xr
from ocf_blosc2 import Blosc2
from torch.utils.data import DataLoader, IterableDataset
from torchinfo import summary
import json
# Enlarge the default matplotlib figure size for every plot drawn in this notebook.
plt.rcParams.update({"figure.figsize": (20, 12)})
# Development aid (disabled): uncomment both magics to have IPython re-import
# edited local modules automatically before each cell runs.
# %load_ext autoreload
# %autoreload 2

In [2]:
# Pick the compute device once; later cells move the model and batches to it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


## Train a model

In [3]:
from dataset import HDF5Dataset

# Training split, read from the pre-processed HDF5 file.
# NOTE(review): the four positional booleans are opaque here -- presumably they
# toggle which feature groups the dataset yields; confirm against HDF5Dataset's
# signature and consider passing them by keyword.
dataset = HDF5Dataset(
    "./data/ds14_processed_data/processed_train.hdf5",
    True,
    True,
    True,
    True,
)

# Shuffled mini-batches of 16, loaded by 8 worker processes into pinned memory.
data_loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=16,
    num_workers=8,
    pin_memory=True,
)
print(f"train dataset len: {len(dataset)}")

train dataset len: 56184


In [4]:
from submission.model import OurResnet2

# Number of passes over the training set; also fixes the length of the
# one-cycle learning-rate schedule below.
EPOCHS = 200

# Network, loss and optimiser.
model = OurResnet2(image_size=128)
model = model.to(device)
criterion = nn.L1Loss()
optimiser = optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=0.05)

# The schedule spans EPOCHS * len(data_loader) steps, so it must be stepped
# once per batch (the training loop does this).
# NOTE(review): OneCycleLR manages the learning rate itself (peaking at max_lr),
# so the lr=1e-4 given to AdamW is presumably not the rate actually used -- confirm.
lr_scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer=optimiser,
    max_lr=1e-2,
    steps_per_epoch=len(data_loader),
    epochs=EPOCHS,
)

# Layer-by-layer summary for a batch of one. The four shapes line up with the
# four positional inputs the training loop passes to the model (pv features,
# HRV features with an added channel axis, NWP, extra features).
summary(
    model,
    input_size=[
        (1, 12),
        (1, 1, 12, 128, 128),
        (1, 10, 6, 128, 128),
        (1, 4),
    ],
)

Layer (type:depth-idx)                             Output Shape              Param #
OurResnet2                                         [1, 48]                   --
├─VideoResNet: 1-1                                 [1, 256]                  --
│    └─R2Plus1dStem: 2-1                           [1, 64, 12, 64, 64]       --
│    │    └─Conv3d: 3-1                            [1, 45, 12, 64, 64]       2,205
│    │    └─BatchNorm3d: 3-2                       [1, 45, 12, 64, 64]       90
│    │    └─ReLU: 3-3                              [1, 45, 12, 64, 64]       --
│    │    └─Conv3d: 3-4                            [1, 64, 12, 64, 64]       8,640
│    │    └─BatchNorm3d: 3-5                       [1, 64, 12, 64, 64]       128
│    │    └─ReLU: 3-6                              [1, 64, 12, 64, 64]       --
│    └─Sequential: 2-2                             [1, 64, 12, 64, 64]       --
│    │    └─BasicBlock: 3-7                        [1, 64, 12, 64, 64]       222,016
│    │    └─BasicBlock:

In [5]:
from tqdm import tqdm

# START_EPOCH only offsets the epoch number used in log lines and checkpoint
# file names; the loop below always runs EPOCHS iterations regardless of it.
START_EPOCH = 0
MODEL_KEY="Extra_TemporalResnet2+1Combo-Full-Weather-NewOptim-FC"
print(f"Training model key {MODEL_KEY}")

# On CUDA, autocast runs the forward pass in float16, where small gradients can
# underflow to zero. Scaling the loss before backward() (and unscaling before
# the optimiser step) avoids that. The scaler is disabled on CPU, which makes
# every scaler call below a plain pass-through.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

for epoch in range(EPOCHS):
    model.train()

    # Sample-weighted running mean of the loss over the current epoch.
    running_loss = 0.0
    count = 0
    for i, (pv_features, hrv_features, nwp, extra, pv_targets) in (pbar := tqdm(enumerate(data_loader), total=len(data_loader))):
        optimiser.zero_grad()
        with torch.autocast(device_type=device):
            # Insert a singleton axis at dim 1 so HRV matches the 5-D input the model expects.
            hrv_features = torch.unsqueeze(hrv_features, 1)
            predictions = model(
                pv_features.to(device,dtype=torch.float),
                hrv_features.to(device,dtype=torch.float),
                nwp.to(device,dtype=torch.float),
                extra.to(device,dtype=torch.float),
            )
            loss = criterion(predictions, pv_targets.to(device, dtype=torch.float))

        # Backward on the scaled loss, then unscale in place so the clipping
        # threshold applies to the true gradient magnitudes.
        scaler.scale(loss).backward()
        scaler.unscale_(optimiser)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        # scaler.step() skips the optimiser update if any gradient is inf/NaN;
        # scaler.update() then adjusts the scale factor for the next iteration.
        scaler.step(optimiser)
        scaler.update()

        # Weight by batch size so a smaller final batch does not skew the mean.
        size = int(pv_targets.size(0))
        running_loss += float(loss) * size
        count += size

        if i % 10 == 9:
            pbar.set_description(f"Epoch {START_EPOCH + epoch + 1}, {i + 1}: {running_loss / count}")
        if i % 100 == 99:
            print(f"Epoch {START_EPOCH + epoch + 1}, {i + 1}: {running_loss / count}")

        # OneCycleLR was sized for EPOCHS * len(data_loader) steps, so it is
        # advanced once per batch (even when the scaler skipped the update).
        lr_scheduler.step()

    print(f"Epoch {START_EPOCH + epoch + 1}: {running_loss / count}")
    # One checkpoint per epoch.
    # NOTE(review): this writes to the absolute path /data/, whereas the training
    # data is read from the relative ./data/ -- confirm the absolute path is intended.
    torch.save(model.state_dict(), f"/data/{MODEL_KEY}-ep{START_EPOCH + epoch + 1}.pt")
    print("Saved model!")

Training model key Extra_TemporalResnet2+1Combo-Full-Weather-NewOptim-FC


Epoch 1, 60: 0.5330028793464104:   2%|▏         | 65/3512 [00:47<41:39,  1.38it/s]


KeyboardInterrupt: 

In [None]:
# Save your model
# Disabled by default: uncomment the line below to write the weights of the
# in-memory `model` to submission/model.pt.
# torch.save(model.state_dict(), "submission/model.pt")