In [16]:
import torch
import torch.optim as optim
import torch.nn as nn
import torch.utils
import torch.utils.data
import numpy as np
from PIL import Image

import ds
import importlib

os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
np.random.seed(0)

In [2]:
importlib.reload(ds)

MINERL_GYM_ENV   = "MineRLTreechopVectorObf-v0"
MINERL_DATA_ROOT = os.getenv('MINERL_DATA_ROOT')
assert MINERL_DATA_ROOT is not None

train_names, valid_names, test_names = ds.create_train_valid_test_split(MINERL_GYM_ENV, train_percent=.8, valid_percent=.1)
print(f"Split dataset into train:{len(train_names)} valid:{len(valid_names)} test:{len(test_names)}")

Split dataset into train:168 valid:21 test:20


In [3]:
importlib.reload(ds)

root_ds_path = os.path.join(MINERL_DATA_ROOT, MINERL_GYM_ENV)
assert os.path.exists(root_ds_path)
root_ds_path = os.path.join(root_ds_path, "frame_dump")

train_ds_dir = ds.ensure_dataset("train", train_names, root_ds_path)
valid_ds_dir = ds.ensure_dataset("valid", valid_names, root_ds_path)
test_ds_dir  = ds.ensure_dataset("test",  test_names,  root_ds_path)

print("Train",      train_ds_dir)
print("Validation", valid_ds_dir)
print("Test",       test_ds_dir)

Train /home/basidio/Development/omscs/dl/DiamondsInTheRough/setup/../raw_data/MineRLTreechopVectorObf-v0/frame_dump/train
Validation /home/basidio/Development/omscs/dl/DiamondsInTheRough/setup/../raw_data/MineRLTreechopVectorObf-v0/frame_dump/valid
Test /home/basidio/Development/omscs/dl/DiamondsInTheRough/setup/../raw_data/MineRLTreechopVectorObf-v0/frame_dump/test


In [4]:
importlib.reload(ds)

train_ds = ds.MineRLFrameDataset(train_ds_dir)
valid_ds = ds.MineRLFrameDataset(valid_ds_dir)
test_ds  = ds.MineRLFrameDataset(test_ds_dir)

print(f"Train({len(train_ds)}), Valid({len(valid_ds)}), Test({len(test_ds)})")

Train(352680), Valid(43915), Test(44729)


In [33]:
import train
import model_zoo
importlib.reload(model_zoo)
importlib.reload(train)

model = model_zoo.ConvAutoEncoder(n_input_channels=3)
if torch.cuda.is_available():
    print(f"Moving model to cuda")
    model = model.cuda()

crit  = nn.BCELoss()
opt   = torch.optim.Adam(model.parameters(), lr=0.001)

train_loader = torch.utils.data.DataLoader(train_ds, batch_size=128, num_workers=0, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_ds, batch_size=128, num_workers=0, shuffle=True)

# LOG_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "logs")
LOG_PATH = os.path.join("/home/basidio/Development/omscs/dl/DiamondsInTheRough/conv_auto_encoder", "logs")
def valid_callback(epoch_i, nograd_model, x_in, out_dir=LOG_PATH):
    x_in = train.preprocess(nograd_model, torch.tensor((np.array([train_ds[0]]))))
    x_out = nograd_model(x_in)
    print(x_out.shape)
    first = x_out[0].cpu().detach().numpy()
    first = np.moveaxis(first, 0, -1)

    os.makedirs(out_dir, exist_ok=True)

    print(first.shape)
    im = Image.fromarray(np.uint8(first * 255))
    out_path = os.path.join(out_dir, f"{epoch_i}.jpg")
    im.save(out_path)

train.train(model, crit, opt, 150, train_loader, valid_loader, valid_epoch_freq=10, valid_callback=valid_callback)

  1%|▏         | 2/150 [00:00<00:10, 13.52it/s]Moving model to cuda
torch.Size([1, 3, 64, 64])
(64, 64, 3)
0, 0.032437, 0.032389
1, 0.032385, -inf
  3%|▎         | 4/150 [00:00<00:09, 14.79it/s]2, 0.032351, -inf
3, 0.032305, -inf
4, 0.032275, -inf
5, 0.032229, -inf
  6%|▌         | 9/150 [00:00<00:08, 16.94it/s]6, 0.032190, -inf
7, 0.032165, -inf
8, 0.032136, -inf
9, 0.032091, -inf
  9%|▉         | 14/150 [00:00<00:08, 16.91it/s]torch.Size([1, 3, 64, 64])
(64, 64, 3)
10, 0.032057, 0.032019
11, 0.032000, -inf
12, 0.031977, -inf
13, 0.031929, -inf
 13%|█▎        | 19/150 [00:01<00:07, 18.20it/s]14, 0.031870, -inf
15, 0.031819, -inf
16, 0.031746, -inf
17, 0.031677, -inf
18, 0.031596, -inf
 14%|█▍        | 21/150 [00:01<00:07, 16.30it/s]19, 0.031556, -inf
torch.Size([1, 3, 64, 64])
(64, 64, 3)
20, 0.031435, 0.031317
21, 0.031323, -inf
22, 0.031228, -inf
 17%|█▋        | 26/150 [00:01<00:06, 17.78it/s]23, 0.031083, -inf
24, 0.030929, -inf
25, 0.030724, -inf
26, 0.030599, -inf
 20%|██       