In [None]:
import math
import os
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from dataset import *
from model import FeatureNetC

In [None]:
class Args:
    """Run configuration for the representation pre-training notebook.

    Every setting is a keyword argument whose default is the value the
    notebook has always used, so ``Args()`` behaves exactly as before while
    ``Args(feat_size=64)`` (for example) creates a variant without editing
    the class body.

    Attributes:
        gpu: CUDA device index, formatted into the ``'cuda:%d'`` device string.
        run_label: Integer run id; used in the log directory name and as the
            random seed, and bumped until an unused log directory is found.
        learning_rate: Learning rate handed to the optimizer.
        feat_size: Feature size passed to ``FeatureNetC`` and embedded in the
            log directory / checkpoint file names.
        model_name: Short model tag embedded in the log directory /
            checkpoint file names.
    """

    def __init__(self, gpu=0, run_label=0, learning_rate=1e-4, feat_size=128, model_name='c'):
        self.gpu = gpu
        self.run_label = run_label
        self.learning_rate = learning_rate
        self.feat_size = feat_size
        self.model_name = model_name
        
# Instantiate the run configuration with its default settings.
args = Args()
# Use the configured GPU when CUDA is usable; otherwise fall back to the CPU so
# the later `.to(device)` calls do not fail on a machine without a GPU.
device = torch.device(('cuda:%d' % args.gpu) if torch.cuda.is_available() else 'cpu')
# Keep the device on the config object as well, next to the other run settings.
args.device = device

In [None]:
# Claim the first unused run directory. Creating the directory and reacting to
# FileExistsError (instead of checking isdir() first) makes the claim atomic, so
# two runs started at the same moment cannot end up sharing one log directory.
while True:
    args.log_dir = '/data/hdim-forecast/log3/rep/model=%s, feat_size=%d-lr=%.5f-run=%d' % \
        (args.model_name, args.feat_size, args.learning_rate, args.run_label)
    try:
        os.makedirs(args.log_dir)
    except FileExistsError:
        # This run label is already taken; try the next one.
        args.run_label += 1
    else:
        break
print("Run number = %d" % args.run_label)
# Summary-event writer plus a plain-text results file, both inside the run directory.
writer = SummaryWriter(args.log_dir)
log_writer = open(os.path.join(args.log_dir, 'results.txt'), 'w')

start_time = time.time()
global_iteration = 0
# Seed every RNG in use with the run label, so different runs get different
# seeds. NumPy is seeded too, because any np.random-based sampling would
# otherwise stay unseeded and differ between otherwise identical runs.
random.seed(args.run_label)
np.random.seed(args.run_label)
torch.manual_seed(args.run_label)
    
def log_scalar(name, value, epoch):
    """Record one scalar in both logging sinks.

    The value is sent to the module-level summary ``writer`` under ``name`` at
    step ``epoch``, and then appended to the module-level ``log_writer`` text
    file as a fixed-point number followed by a single space (no newline).
    """
    writer.add_scalar(name, value, epoch)
    text_entry = '%f ' % value
    log_writer.write(text_entry)

In [None]:
# Training split of the MovingMNIST dataset, configured with n_past=1 and
# n_future=1 and with deterministic generation switched off.
train_dataset = MovingMNIST(
    train=True,
    n_past=1,
    n_future=1,
    deterministic=False,
)
# Shuffled batches of 32 samples, loaded with 4 workers.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
)
# Report the number of samples in the training set.
print(len(train_dataset))

In [None]:
def preprocess(bx, by, bl, device):
    """Flatten one loader batch into per-frame inputs and targets on ``device``.

    Args:
        bx: Image batch; reshaped to ``(-1, 1, 64, 64)``.
        by: Integer class ids, one row per sample (assumed ``(batch, n_digits)``
            — TODO confirm against the dataset).
        bl: Per-frame values indexed as ``bl[:, :, 0]`` / ``bl[:, :, 1]``; its
            second dimension is used as the number of frames per sample
            (assumed ``(batch, frames, 2, 2)`` so that the final
            ``view(-1, 4)`` lines up with the frames — TODO confirm).
        device: Target device for all three outputs.

    Returns:
        ``(bx, by, bl)`` where ``bx`` is ``(-1, 1, 64, 64)``, ``by`` is a
        float32 ``(-1, 10)`` multi-hot class indicator (capped at 1) repeated
        for every frame, and ``bl`` is ``(-1, 4)`` holding, per frame, the sum
        of the two entries along dim 2 followed by their absolute difference.
    """
    frames_per_sample = bl.shape[1]

    images = bx.view(-1, 1, 64, 64).to(device)

    # Count class occurrences per sample, then copy that vector to every frame.
    class_counts = F.one_hot(by.type(torch.long), num_classes=10).sum(dim=1, keepdim=True)
    labels = class_counts.repeat(1, frames_per_sample, 1).view(-1, 10)
    labels = labels.to(device).type(torch.float32)
    # A class that occurs more than once still counts as a plain "present".
    labels = labels.clamp(max=1.0)

    # Order-independent encoding of the two entries along dim 2: sum and |difference|.
    pair_sum = bl.sum(dim=2)
    pair_gap = (bl[:, :, 0] - bl[:, :, 1]).abs()
    locations = torch.stack([pair_sum, pair_gap], dim=2).view(-1, 4).to(device)

    return images, labels, locations

In [None]:
# Build the feature network (constructed with train_mode=True) for the
# configured feature size, then move it onto the selected device.
model = FeatureNetC(args.feat_size, train_mode=True)
model = model.to(device)
# Adam over all model parameters, using the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

In [None]:
# Debugging aid: pull a single batch to inspect it before starting a full run.
# bx, by, bl = next(iter(train_loader))
# bx, by, bl = preprocess(bx, by, bl, device)

# Make sure the checkpoint directory exists before training starts; otherwise the
# first torch.save would fail only after a whole epoch has already been computed.
checkpoint_path = 'pretrained/representation-%s-%d.pt' % (args.model_name, args.feat_size)
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(3000):
    for data in train_loader:
        optimizer.zero_grad()
        bx, by, bl = preprocess(*data, device)
        feat, output = model(bx)
        # Output columns 0-9 are the class logits, columns 10-13 the location regression.
        label = output[:, :10]
        loc = output[:, 10:14]

        # Weighted multi-label BCE: present classes (by == 1) get weight 1.3333,
        # absent classes get weight 0.3333.
        loss_label = F.binary_cross_entropy_with_logits(input=label, target=by, weight=by+0.3333)
        # loss_label = (label - by).pow(2).sum(axis=1).mean()
        # Squared error on the location targets, summed per row and averaged over rows.
        loss_loc = (loc - bl).pow(2).sum(dim=1).mean()
        # Regulariser: pull the per-dimension std of the features (taken over
        # the batch) towards 1.
        loss_reg = (feat.std(dim=0) - 1.0).pow(2).mean()
        loss_all = loss_label + loss_loc + loss_reg

        loss_all.backward()
        optimizer.step()

        # Monitoring only: keep the metric out of the autograd graph.
        with torch.no_grad():
            # torch.sigmoid replaces the deprecated F.sigmoid.
            binarized = (torch.sigmoid(label) > 0.5).type(torch.float32)
            # Fraction of the present classes that are predicted present.
            recall = (binarized * by).mean() / by.mean()

        # Log plain Python floats so no graph-attached tensors reach the writer.
        writer.add_scalar('loss_label', loss_label.item(), global_iteration)
        writer.add_scalar('loss_loc', loss_loc.item(), global_iteration)
        writer.add_scalar('loss_reg', loss_reg.item(), global_iteration)
        writer.add_scalar('recall', recall.item(), global_iteration)
        global_iteration += 1
    if epoch % 100 == 0:
        print("Epoch %d, time %.2f" % (epoch, time.time() - start_time))
        torch.save(model.state_dict(), checkpoint_path)

# The periodic checkpoint last fires at epoch 2900; save once more so the weights
# from the final epochs are not lost.
torch.save(model.state_dict(), checkpoint_path)