In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import numpy as np

from models.autoencoder import Autoencoder
from dataset_extracted import ExtractedFeatureDatasetOnlyFeature
import os
import time

In [2]:
# Hyperparameters for the autoencoder training run.
latent = 512   # latent (bottleneck) width for the autoencoder
lr = 1e-4      # learning rate handed to the optimizer
BATCH = 1      # batch size used by the DataLoader
EPOCH = 100    # number of training epochs

In [3]:
# Collect the feature files to train on: for every entry of the split file
# there is one '<name>.npy' in the feature directory and one in its
# '_fliped' sibling directory.
feature_path = '../MVAD/I3D_rgb_kinetics/train'
with open('../MVAD/train_fine') as f:
    files = f.readlines()
    feature_files = [
        os.path.join(feature_path, file.strip() + '.npy')
        for file in files
    ]
    feature_files += [
        os.path.join(feature_path + '_fliped', file.strip() + '.npy')
        for file in files
    ]

# Wrap the file list in the project dataset and iterate over it in shuffled batches.
dataset = ExtractedFeatureDatasetOnlyFeature(None, feature_files=feature_files)
dataloader = DataLoader(dataset, batch_size=BATCH, shuffle=True)

In [4]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs on machines where CUDA is not available.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [5]:
# Build the autoencoder. The bottleneck width is taken from `latent` in the
# hyperparameter cell so that changing the constant actually changes the model.
model = Autoencoder(latent)
model.to(device)   # nn.Module.to moves the parameters in place
model.train()      # training mode; as the last expression it also displays the module

Autoencoder(
  (encoder_linear1): Linear(in_features=1024, out_features=512, bias=True)
  (decoder_linear1): Linear(in_features=512, out_features=1024, bias=True)
)

In [6]:
# L1 loss between the model output and its input, optimised with Adam
# over all model parameters.
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=lr)

In [7]:
# Train the autoencoder to reconstruct its input features.
# After every 10th epoch a checkpoint is written to
# '../checkpoint/autoencoder<k>' with k = epoch // 10.

# Create the checkpoint directory up front so the first save cannot fail on a
# missing folder after several epochs of training.
os.makedirs('../checkpoint', exist_ok=True)

start_time = time.time()
for epoch in range(1, EPOCH + 1):
    train_loss = 0.0
    # No accuracy is computed for this reconstruction objective; the value
    # stays 0 and is kept only so the log line keeps its existing format.
    train_accuracy = 0.0
    batch = 0  # stays 0 if the dataloader yields nothing
    for batch, feature in enumerate(dataloader, start=1):
        feature = feature.to(device)
        out = model(feature)
        loss = criterion(out, feature)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() gives a plain Python float, so the running sum does not keep
        # autograd history alive across the whole epoch.
        train_loss += loss.item()

    # Avoid dividing by zero when the dataloader is empty.
    num_batches = max(batch, 1)
    print('epoch {} {:.1f}s - train_loss: {:.6f} train_accuracy: {:.4f}'.format(epoch, time.time() - start_time, train_loss / num_batches, train_accuracy / num_batches))

    # Save only on multiples of 10 (epoch 10 -> autoencoder1, 20 -> autoencoder2, ...).
    if epoch % 10 == 0:
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),}, '../checkpoint/autoencoder{}'.format(epoch // 10))

epoch 1 157.4s - train_loss: 0.037921 train_accuracy: 0.0000
epoch 2 313.8s - train_loss: 0.029148 train_accuracy: 0.0000
epoch 3 468.1s - train_loss: 0.028370 train_accuracy: 0.0000
epoch 4 623.8s - train_loss: 0.028019 train_accuracy: 0.0000
epoch 5 779.5s - train_loss: 0.027819 train_accuracy: 0.0000
epoch 6 935.5s - train_loss: 0.027684 train_accuracy: 0.0000
epoch 7 1097.0s - train_loss: 0.027592 train_accuracy: 0.0000
epoch 8 1254.6s - train_loss: 0.027512 train_accuracy: 0.0000
epoch 9 1413.4s - train_loss: 0.027455 train_accuracy: 0.0000
epoch 10 1571.9s - train_loss: 0.027413 train_accuracy: 0.0000
epoch 11 1729.2s - train_loss: 0.027370 train_accuracy: 0.0000
epoch 12 1881.9s - train_loss: 0.027334 train_accuracy: 0.0000
epoch 13 2041.0s - train_loss: 0.027303 train_accuracy: 0.0000
epoch 14 2198.5s - train_loss: 0.027277 train_accuracy: 0.0000
epoch 15 2352.4s - train_loss: 0.027259 train_accuracy: 0.0000
epoch 16 2504.2s - train_loss: 0.027226 train_accuracy: 0.0000
epoch 1

KeyboardInterrupt: 