In [1]:
import os

import numpy as np
import torch
import torch.nn as nn
# import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import wandb

In [2]:
%load_ext autoreload
%autoreload 2
from models import VisualEncoder, VisualDecoder

In [3]:
# Load the IGLU images and split them into train / validation loaders.
SPLIT_SEED = 42      # fixes the train/val split so runs are comparable
N_TRAIN = 100000     # number of training samples; all remaining samples go to validation
BATCH_SIZE = 256

images = np.load('data/images.npy')  # path to the images from IGLU
# Move the last axis to dim 1 (channels-last -> channels-first); the encoder's first
# Conv2d expects 3 input channels on dim 1.
images = np.transpose(images, (0, 3, 1, 2))
images = torch.Tensor(images)  # float32 copy of the (transposed) array
train_dataset = images
assert len(train_dataset) > N_TRAIN, (
    f"need more than {N_TRAIN} images for a train/val split, got {len(train_dataset)}")
# Derive the validation size from the data instead of hard-coding it, so the split
# keeps working when data/images.npy changes size.
train_set, val_set = torch.utils.data.random_split(
    train_dataset,
    [N_TRAIN, len(train_dataset) - N_TRAIN],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Fully qualified name on purpose: a loop variable named `data` elsewhere in the
# notebook can shadow the `torch.utils.data` alias imported at the top.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True,
                                           drop_last=True, pin_memory=False, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False,
                                         drop_last=False, num_workers=4)

## Training AtariCNN

In [17]:
class Autoencoder(nn.Module):
    """Autoencoder made of an encoder module followed by a decoder module.

    Args:
        encoder_class: class instantiated with no arguments to build the encoder.
        decoder_class: class instantiated with no arguments to build the decoder;
            it must accept the encoder's output as its input.
    """

    def __init__(self,
                 encoder_class: type = VisualEncoder,
                 decoder_class: type = VisualDecoder):
        super().__init__()
        self.encoder = encoder_class()
        self.decoder = decoder_class()

    def forward(self, x):
        """Encode `x` and return the decoder's reconstruction of it."""
        latent = self.encoder(x)
        return self.decoder(latent)
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Autoencoder()
model.to(DEVICE)

Autoencoder(
  (encoder): VisualEncoder(
    (cnn): Sequential(
      (0): Conv2d(3, 32, kernel_size=(8, 8), stride=(4, 4))
      (1): ReLU()
      (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
      (3): ReLU()
      (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (5): ReLU()
      (6): Flatten(start_dim=1, end_dim=-1)
    )
    (linear): Sequential(
      (0): Linear(in_features=1024, out_features=512, bias=True)
      (1): ReLU()
    )
  )
  (decoder): VisualDecoder(
    (linear): Sequential(
      (0): Linear(in_features=512, out_features=1024, bias=True)
      (1): ReLU()
    )
    (cnn): Sequential(
      (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
      (2): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), output_padding=(1, 1))
      (3): ReLU()
      (4): ConvTranspose2d(32, 3, kernel_size=(8, 8), stride=(4, 4))
      (5): Tanh()
    )
  )
)

In [18]:
# Hyper-parameters tracked by Weights & Biases. `batch_size` is read from the
# DataLoader so the logged value always matches what the loaders actually use.
config_defaults = {
    'epochs': 50,
    'batch_size': train_loader.batch_size,
    'learning_rate': 1e-3,
    'optimizer': 'adam',
    'scheduler_step_size': 30,
    'scheduler_gamma': 0.5
}

wandb.init(project='AutoEncoder', entity='neuro_ai', name='AtariCNN', config=config_defaults)
config = wandb.config
if config.optimizer != 'adam':
    # Only Adam is implemented below; fail loudly instead of silently ignoring the setting.
    raise ValueError(f"Unsupported optimizer: {config.optimizer!r}")
optimizer = optim.Adam(model.parameters(), config.learning_rate)
criterion = torch.nn.MSELoss()
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=config.scheduler_step_size,
                                            gamma=config.scheduler_gamma)
# Send batches to wherever the model's parameters live instead of hard-coding 'cuda'.
device = next(model.parameters()).device


def run_epoch(loader, train):
    """Run one pass over `loader` and return the mean per-batch MSE reconstruction loss.

    Each batch is divided by 255 and used both as model input and as the target.
    With `train=True` the model is in train mode and parameters are updated;
    with `train=False` it is in eval mode and autograd is disabled, so no graph
    is built during validation.
    """
    model.train(train)
    losses = []
    with torch.set_grad_enabled(train):
        # Loop variable is `batch`, not `data`: `data` is the torch.utils.data alias.
        for batch in loader:
            batch = batch.to(device) / 255
            # NOTE(review): the decoder ends in Tanh (range [-1, 1]); if raw pixels are
            # 0..255 the targets lie in [0, 1] — consider Sigmoid or rescaling. TODO confirm.
            predict = model(batch)
            loss = criterion(predict, batch)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            losses.append(loss.item())
    return float(np.mean(losses))


for epoch in range(config.epochs):
    train_loss = run_epoch(train_loader, train=True)
    val_loss = run_epoch(val_loader, train=False)
    scheduler.step()
    wandb.log({"train_loss": train_loss, 'val_loss': val_loss})

In [19]:
# Save encoder and decoder weights to separate files so each half can be reloaded on its own.
# The "weigths" spelling is kept as-is so any code loading these paths keeps working.
SAVE_DIR = 'models/AtariCNN'
os.makedirs(SAVE_DIR, exist_ok=True)  # torch.save does not create missing directories
torch.save(model.encoder.state_dict(), os.path.join(SAVE_DIR, 'encoder_weigths.pth'))
torch.save(model.decoder.state_dict(), os.path.join(SAVE_DIR, 'decoder_weigths.pth'))