In [None]:
# Mount Google Drive at /content/drive so files under MyDrive are reachable from this Colab runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Enter the project directory on the mounted Drive. Later cells use paths relative to it
# (e.g. "data-cd.zip"); presumably the local modules imported below (config, dataset_v2,
# net, train, utils) also live here — TODO confirm.
%cd /content/drive/MyDrive/ICSA_DLcourse/survival/time_series

In [1]:
import os
import zipfile


def _already_extracted(archive):
    """Return True when every member of `archive` already exists in the working directory.

    Regular files must also match their uncompressed size, so a truncated earlier
    extraction is detected and redone.
    """
    for info in archive.infolist():
        if info.is_dir():
            if not os.path.isdir(info.filename):
                return False
        elif not (os.path.isfile(info.filename)
                  and os.path.getsize(info.filename) == info.file_size):
            return False
    return True


# Unpack the dataset archive into the current working directory. When a complete
# extraction is already on disk it is kept as-is: rewriting the whole dataset onto
# Drive on every kernel restart is slow.
with zipfile.ZipFile("data-cd.zip", 'r') as zip_ref:
    if not _already_extracted(zip_ref):
        zip_ref.extractall()

# Import packages

In [None]:
%load_ext autoreload
%autoreload 2
import os
import argparse
import time
import itertools
import random
import shutil
import numpy as np

import torch
import torch.distributed as dist
import torch.nn as nn

from config import config as cfg
from dataset_v2 import NYCDataset
from net import STFORMER
from transformers import InformerConfig, InformerModel
from transformers import AutoformerConfig, AutoformerForPrediction

from utils.misc import mkdir

from utils.logger import setup_logger
from utils.collect_env import collect_env_info
from utils import comm

from train import train, validate, RMSE, CPC, WMAPE, MAPE

In [None]:
train_dataset = NYCDataset(cfg, is_train=True)
test_dataset = NYCDataset(cfg, is_train=False)


def _make_loader(dataset):
    """Wrap `dataset` in an unshuffled DataLoader configured from cfg.DATALOADER."""
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=cfg.DATALOADER.BATCH_SIZE,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        pin_memory=False,
        shuffle=False,
    )


# Both splits are iterated in order (shuffle=False); the train loader is only
# used below to inspect a sample batch.
train_dataloader = _make_loader(train_dataset)
test_dataloader = _make_loader(test_dataset)

In [None]:
# Pull one batch from the train loader to check which fields the dataset yields.
sample = next(iter(train_dataloader))

# Report every field under its real key instead of hand-picked aliases. The previous
# labels paired "past_value" with `dec_x` (the decoder-side input), which was misleading.
# Non-tensor entries (no `.shape`) are reported by type name.
for key, value in sample.items():
    shape = tuple(value.shape) if hasattr(value, "shape") else type(value).__name__
    print(f"{key}: {shape}")

In [None]:
# Build the network with float32 parameters, plus its Adam optimizer. In this notebook
# the optimizer is only used to restore the checkpoint's optimizer state.
model = STFORMER(cfg).to(dtype=torch.float32)
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=cfg.SOLVER.LR,
    betas=(0.5, 0.999),  # beta1 = 0.5 instead of Adam's 0.9 default
)

# Load the pre-trained checkpoint and evaluate it on the test set

In [None]:
class Args:
    """Minimal stand-in for a command-line argument namespace.

    An instance is passed to `validate` below; its attributes mirror the options
    the training script presumably reads from argparse — TODO confirm against train.py.
    """

    def __init__(self):
        # Evaluation run: no resume flag, single process (rank 0), starting at epoch 0.
        self.resume, self.local_rank = False, 0
        self.seed = 1227
        self.print_freq, self.start_epoch = 20, 0


args = Args()

In [None]:
# Restore the best checkpoint saved during training and score it on the test split.
resume_path = os.path.join(cfg.checkpoint_dir, 'model_best.pth.tar')
print(f"=> loading checkpoint {resume_path}")

# The model is never moved off the CPU in this notebook, so map stored tensors to CPU too.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))

# Restore the trained weights. The optimizer state is restored as well, although the
# evaluation below does not use the optimizer.
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

# Evaluate on the test split. validate() returns five values; only the first three are
# used. They were previously computed but never shown, so report them explicitly.
wmape, mape, rmse, _, _ = validate(test_dataloader, model, args, cfg)
print(f"WMAPE: {wmape}  MAPE: {mape}  RMSE: {rmse}")