In [None]:
from collections import OrderedDict
import sys
from pathlib import Path
import json
import time
import math
import csv
from datetime import datetime
import random
from random import randint
import itertools
import torchvision.utils as vutils

import gc

import numpy as np
import matplotlib.pyplot as plt
import sklearn as skl
from sklearn.manifold import TSNE
import sklearn.cluster

import cv2

import kornia

import torch
import torch.nn as nn

import torchvision as tv
from torchvision.models import resnet50, ResNet50_Weights, resnet34
from torchvision.transforms import functional as tfunc

sys.path.append('../mine_soar')
import MalmoPython
from utils import draw_image, draw_layers, MinecraftHandler

# NOTE(review): MISSION_PORT and VIDEO_DEPTH are not referenced in this notebook;
# presumably the Malmo mission port and the frame channel count — confirm.
MISSION_PORT = 9001
VIDEO_SHAPE = (128, 128)  # frame size given to MinecraftHandler (assumed (height, width) — TODO confirm)
VIDEO_DEPTH = 3
# Prefer the first GPU but fall back to the CPU so the notebook still runs on
# machines without CUDA. Assign DEVICE = 'cpu' explicitly to force the CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Shared client handle used by start_mission, get_mc_img and test_model.
mc_handler = MinecraftHandler(VIDEO_SHAPE)

def start_mission(file="random_world.xml", seed=b''):
    """Start a mission on the shared ``mc_handler``.

    Args:
        file: mission file name, forwarded unchanged.
        seed: bytes forwarded unchanged (presumably the world seed — see
            ``MinecraftHandler.start_mission``).
    """
    handler_args = (file, seed)
    mc_handler.start_mission(*handler_args)

def get_mc_img(show=False):
    """Fetch the current frame from ``mc_handler``, optionally drawing it first.

    Args:
        show: when true, the frame is also passed to ``draw_image(..., show=True)``.

    Returns:
        Whatever ``mc_handler.get_image()`` returns (callers treat it as a torch
        tensor and call ``.to(DEVICE)`` on it — assumed, confirm in utils).
    """
    frame = mc_handler.get_image()
    if not show:
        return frame
    draw_image(frame, show=True)
    return frame

In [None]:
class BottleneckBlock(nn.Module):
    """Conv bottleneck: [downscale] -> 1x1 squeeze -> 5x5 conv -> 1x1 expand -> [upscale].

    Every stage is ``conv -> BatchNorm2d -> non_linearity()``.

    Args:
        in_dims: channels of the input tensor.
        mid_dims: channels inside the bottleneck (the 5x5 conv).
        out_dims: channels of the output tensor.
        padding: ``padding_mode`` for the spatial convs (3x3 downscale and 5x5);
            the 1x1 convs always use the default zero padding.
        downscale: prepend a stride-2 3x3 conv (keeps ``in_dims`` channels), which
            halves height and width (rounding up).
        upscale: append a stride-2 2x2 transposed conv (keeps ``out_dims``
            channels), which doubles height and width.
        non_linearity: activation *class*; a fresh instance is created per stage.
    """

    def __init__(self, in_dims, mid_dims, out_dims, padding="zeros", downscale=False, upscale=False, non_linearity=nn.ReLU):
        super().__init__()

        self.downscale = downscale
        self.upscale = upscale

        def stage(conv, channels):
            # The conv is built by the caller before the norm and activation, the same
            # order as a literal nn.Sequential(...), so weight init draws from the RNG
            # in the same sequence.
            return nn.Sequential(conv, nn.BatchNorm2d(channels), non_linearity())

        if downscale:
            self.downscale_layer = stage(nn.Conv2d(in_dims, in_dims, 3, 2, 1, padding_mode=padding), in_dims)
        self.dim_compress = stage(nn.Conv2d(in_dims, mid_dims, 1, 1, 0), mid_dims)
        self.bottleneck = stage(nn.Conv2d(mid_dims, mid_dims, 5, 1, 2, padding_mode=padding), mid_dims)
        self.dim_extend = stage(nn.Conv2d(mid_dims, out_dims, 1, 1, 0), out_dims)
        if upscale:
            self.upscale_layer = stage(nn.ConvTranspose2d(out_dims, out_dims, 2, 2, 0), out_dims)

    def forward(self, x):
        pipeline = [self.dim_compress, self.bottleneck, self.dim_extend]
        if self.downscale:
            pipeline.insert(0, self.downscale_layer)
        if self.upscale:
            pipeline.append(self.upscale_layer)
        for layer in pipeline:
            x = layer(x)
        return x

In [None]:
class TestModel(nn.Module):
    """Convolutional autoencoder built from mirrored stacks of BottleneckBlocks.

    The encoder (``compress``) applies one BottleneckBlock per entry of
    ``features``. The decoder (``decompress``) applies the same blocks in reverse
    order with input/output widths swapped, and upscales in exactly the mirrored
    position of each encoder downscale. For inputs whose height/width are
    divisible by ``2 ** sum(downscales)`` the output therefore matches the input
    shape.

    Args:
        in_channels: channels of the input image (and of the reconstruction).
        features: one entry per encoder block, either ``[in, mid, out]`` channel
            widths or a single int used for all three. Consecutive entries must
            chain (``out`` of block k == ``in`` of block k+1).
        downscales: one bool per entry of ``features``; True makes that encoder
            block downscale and its mirrored decoder block upscale.
        padding: ``padding_mode`` forwarded to the convolutions.
        non_linearity: activation class used inside the network.
        final_nonlin: activation class applied to the reconstruction.
        loss_fn, opt_method, lr: training settings stored on the model for the
            training loop; not used by ``forward``.

    Raises:
        ValueError: if ``downscales`` and ``features`` differ in length.
    """

    @staticmethod
    def _split_dims(num_feats):
        """Return ``(in, mid, out)`` widths for one ``features`` entry (int or 3-sequence)."""
        if isinstance(num_feats, int):
            return num_feats, num_feats, num_feats
        return num_feats[0], num_feats[1], num_feats[2]

    def __init__(self, in_channels=3, features=[[64,32,128], [128, 64, 256]], downscales=[True, False],
                 padding="zeros", non_linearity=nn.ReLU, final_nonlin=nn.Sigmoid,
                 loss_fn=nn.MSELoss(), opt_method=torch.optim.SGD, lr=1e-3):
        super().__init__()

        if len(downscales) != len(features):
            raise ValueError(
                f"downscales has {len(downscales)} entries but features has {len(features)}; "
                "they must match one-to-one")

        # Width entering the first encoder block (works for int or [in, mid, out] entries).
        first_width = self._split_dims(features[0])[0]

        self.intake = nn.Sequential(
            nn.Conv2d(in_channels, first_width, 1, 1, 0, padding_mode=padding),
            nn.BatchNorm2d(first_width),
            non_linearity()
        )

        compress_layers = []
        for num_feats, downscale in zip(features, downscales):
            in_feats, mid_feats, out_feats = self._split_dims(num_feats)
            compress_layers.append(BottleneckBlock(in_feats, mid_feats, out_feats,
                                                   padding=padding, non_linearity=non_linearity,
                                                   downscale=downscale))
        self.compress = nn.Sequential(*compress_layers)

        # Mirror the encoder: walk features AND downscales backwards together, so the
        # decoder block built from encoder block k upscales iff block k downscaled.
        decompress_layers = []
        for num_feats, upscale in zip(reversed(features), reversed(downscales)):
            in_feats, mid_feats, out_feats = self._split_dims(num_feats)
            decompress_layers.append(BottleneckBlock(out_feats, mid_feats, in_feats,
                                                     padding=padding, non_linearity=non_linearity,
                                                     upscale=upscale))
        self.decompress = nn.Sequential(*decompress_layers)

        self.output = nn.Sequential(
            nn.Conv2d(first_width, in_channels, 1, 1, 0, padding_mode=padding),
            nn.BatchNorm2d(in_channels),
            final_nonlin()
        )

        # Hyperparameters kept on the instance for the training loop and run logs.
        self.in_channels = in_channels
        self.features = features
        self.downscales = downscales
        self.padding = padding
        self.non_linearity = non_linearity
        self.final_nonlin = final_nonlin
        self.loss_fn = loss_fn
        self.opt_method = opt_method
        self.lr = lr

    def forward(self, x):
        y = self.intake(x)
        y = self.compress(y)
        y = self.decompress(y)
        y = self.output(y)
        return y

In [None]:
def gen_test_models(in_channels=[3], 
                    features=[[[64,32,128], [128,64,256]]], 
                    downscales=[[True, False]],
                    padding=["replicate"],
                    non_linearity=[nn.LeakyReLU],
                    final_nonlin=[nn.Sigmoid],
                    loss_fn=[nn.MSELoss()], 
                    opt_method=[torch.optim.SGD], 
                    lr=[1e-3]):
    """Yield one TestModel per combination of the given hyperparameter values.

    Every argument is a *list of candidate values* for the TestModel argument of
    the same name. Models are built lazily, one per element of the Cartesian
    product (``itertools.product`` order: the last argument, ``lr``, varies
    fastest), moved to ``DEVICE`` and put in training mode.

    ``features`` and ``downscales`` values are themselves lists, so their
    candidate lists are nested one level deeper, e.g. ``downscales=[[True, False]]``.
    """
    grid = {
        "in_channels": in_channels,
        "features": features,
        "downscales": downscales,
        "padding": padding,
        "non_linearity": non_linearity,
        "final_nonlin": final_nonlin,
        "loss_fn": loss_fn,
        "opt_method": opt_method,
        "lr": lr,
    }
    names = list(grid)
    for values in itertools.product(*grid.values()):
        # Keyword arguments keep each value bound to the right TestModel parameter.
        yield TestModel(**dict(zip(names, values))).to(DEVICE).train()

In [None]:
def _show_reconstructions(step, img_tens, models_out, nrow, images_record):
    """Draw the input frame beside each model's reconstruction and record CPU copies.

    Appends ``(step, [input, *reconstructions])`` to ``images_record`` so saved
    image files can be named after the iteration that produced them.
    """
    # Detach so the displayed/recorded copies do not keep autograd state alive.
    images = [t.detach().squeeze() for t in (img_tens, *models_out)]
    grid_image = vutils.make_grid(images, nrow=nrow, padding=0, pad_value=0.5, normalize=True).cpu()
    draw_image(grid_image, show=True)
    images_record.append((step, [t.cpu() for t in images]))


# Model attributes written to hparams_<k>.txt, in this order.
_HPARAM_NAMES = ("in_channels", "features", "downscales", "padding",
                 "non_linearity", "final_nonlin", "loss_fn", "opt_method", "lr")


def _save_run(save_root, models, loss_steps, losses_record, images_record):
    """Write losses (CSV + plot), recorded images, model reprs and hyperparameters to ``save_root``."""
    with open(save_root / "losses.csv", "w", newline="") as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(["iter", *(f"model_{k}" for k in range(len(models)))])
        for step, losses in zip(loss_steps, losses_record):
            writer.writerow([step, *losses])

    for step, imgs in images_record:
        for j, img in enumerate(imgs):
            draw_image(img, save_root / f"images_{step}_{j}.png")

    fig, ax = plt.subplots(figsize=(16, 8))
    if losses_record:
        lines = ax.plot(loss_steps, losses_record)
        for k, line in enumerate(lines):
            line.set_label(f"model_{k}")
        ax.legend()
    ax.set(title="Reconstruction loss", xlabel="iteration", ylabel="loss")
    fig.savefig(save_root / "losses.png")
    plt.show()
    plt.close(fig)

    for k, model in enumerate(models):
        (save_root / f"model_{k}.txt").write_text(str(model))
        (save_root / f"hparams_{k}.txt").write_text(
            "\n".join(f"{name}={getattr(model, name)}" for name in _HPARAM_NAMES))


def test_model(models, iters=10001, show_step=250, save_step=250, reset_step=250, seed=2):
    """Train several models side by side on live Minecraft frames and save the run.

    Each iteration rotates the player's yaw, grabs one frame, runs one optimizer
    step for every model on it and records the losses. Every ``show_step``
    iterations (and on the iteration right after each mission reset) the input
    and all reconstructions are drawn. Every ``reset_step`` iterations the
    mission is restarted with a fresh random world seed. Results are written to
    ``./runs/<timestamp>/`` at the end, also when interrupted with Ctrl-C, and
    the mission is always sent "quit".

    Args:
        models: iterable of TestModel-like modules exposing ``opt_method``,
            ``lr``, ``loss_fn`` and the attributes in ``_HPARAM_NAMES``.
        iters: number of iterations.
        show_step: display interval in iterations.
        save_step: currently unused; kept for backward compatibility.
        reset_step: mission-restart interval, or None/0 to never restart. Also
            sets the yaw sweep period (``show_step`` is used when it is None/0).
        seed: seed for Python's ``random`` module, which draws the world seeds.
    """
    models = list(models)  # iterated several times below; a generator would be exhausted
    random.seed(seed)
    model_optims = [m.opt_method(m.parameters(), lr=m.lr) for m in models]

    if not mc_handler.is_running:
        start_mission(seed=random.randbytes(16))

    session_id = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_root = Path("./runs") / session_id
    save_root.mkdir(parents=True)  # parents=True: ./runs may not exist yet

    yaw_period = reset_step if reset_step else show_step
    nrow = int((len(models) + 1) ** 0.5)
    loss_steps, losses_record, images_record = [], [], []

    try:
        try:
            for step in range(iters):
                # Turn the player by 359/yaw_period degrees per iteration.
                mc_handler.sendCommand(f"setYaw {(step * (359 / yaw_period)) % 360}")
                try:
                    img_tens = get_mc_img(False).to(DEVICE)
                except Exception as e:
                    # Frame grab failed: report it, restart the mission with a fresh
                    # world seed (like the periodic resets below) and skip this step.
                    print(e)
                    start_mission(seed=random.randbytes(16))
                    continue

                models_out = [m(img_tens) for m in models]
                # PyTorch loss convention: (prediction, target).
                losses = [m.loss_fn(out, img_tens) for m, out in zip(models, models_out)]
                for loss in losses:
                    loss.backward()
                for optim in model_optims:
                    optim.step()
                    optim.zero_grad()

                loss_values = [loss.item() for loss in losses]
                loss_steps.append(step)
                losses_record.append(loss_values)

                just_reset = bool(reset_step) and step % reset_step == 1
                if just_reset or step % show_step == 0:
                    print(f"{step}: {loss_values}")
                    _show_reconstructions(step, img_tens, models_out, nrow, images_record)

                if reset_step and step % reset_step == 0 and 0 < step < iters - 1:
                    mc_handler.sendCommand("quit")
                    start_mission(seed=random.randbytes(16))
        except KeyboardInterrupt as e:
            # Interrupting the kernel stops training early but still saves the run.
            print(e)

        _save_run(save_root, models, loss_steps, losses_record, images_record)
    finally:
        # Close the mission even if training or saving raised.
        mc_handler.sendCommand("quit")

In [None]:
# Hyperparameter sweep: each list is one axis, and gen_test_models builds one
# model per combination (Cartesian product). Commented entries are alternatives.
in_channels = [3]
features = [
#     [[32, 16, 64], [64, 32, 128]],
    [[64, 32, 128], [128, 64, 256]]
]
downscales = [
#     [False, False],
    [True, False],
#     [True, True]
]
padding = [
#     "zeros",
    "replicate",
#     "reflect",
#     "circular"
]
non_linearity = [
    nn.LeakyReLU,
]
final_nonlin = [
    nn.Sigmoid
]
# NOTE(review): not consumed by gen_test_models/TestModel in this notebook.
img_transforms = [
    None,
]
loss_fn = [
    nn.MSELoss(),
]
opt_method = [
    torch.optim.Adagrad
]
lr = [
    1e-3
]

seed = random.randint(0, 2**16)
print(seed)
# Seed torch as well, so the printed seed also reproduces the models' weight init.
torch.manual_seed(seed)
models = list(gen_test_models(in_channels=in_channels,
                              features=features,
                              downscales=downscales,
                              padding=padding,
                              non_linearity=non_linearity,
                              final_nonlin=final_nonlin,
                              loss_fn=loss_fn,
                              opt_method=opt_method,
                              lr=lr))
test_model(models, iters=5001, reset_step=250, show_step=50, seed=seed)
