In [1]:
import sys
from pathlib import Path
from subprocess import Popen
from json import loads
from datetime import datetime
from time import sleep

from torchvision.transforms import functional as tfunc
import torchviz

import numpy as np
import matplotlib.pyplot as plt


from odin import draw_image, draw_layers, ViewGenerator, ODIN

sys.path.append('../mine_soar')
import MalmoPython

In [2]:
MISSION_PORT = 9001
MISSION_FILE = Path("random_world.xml")
VIDEO_SHAPE = (128, 128)
VIDEO_DEPTH = 3
DEVICE = 'cuda:0'

def launch_malmo_client(mission_port):
    MALMO_ARGS = [f"/home/boggsj/Coding/Malmo-prebuilt/Minecraft/launchClient.sh", "-port", str(mission_port)]
    malmo_proc = Popen(MALMO_ARGS, stdout=sys.stdout)

In [3]:
# launch_malmo_client(MISSION_PORT)

In [4]:
MISSION_FILE = Path("random_world.xml")
malmo_agent_host = MalmoPython.AgentHost()
malmo_client_info = MalmoPython.ClientInfo('127.0.0.1', MISSION_PORT)
malmo_client_pool = MalmoPython.ClientPool()
malmo_client_pool.add(malmo_client_info)

random_world_mission_spec = MalmoPython.MissionSpec(MISSION_FILE.read_text(), True)
random_world_record_spec = MalmoPython.MissionRecordSpec()
malmo_agent_host.startMission(random_world_mission_spec, 
                              malmo_client_pool,
                              random_world_record_spec,
                              0,
                              "TEST ALPHA")

In [5]:
# malmo_agent_host.sendCommand("jump 1")
# malmo_agent_host.sendCommand("turn 0.1")

In [6]:
# malmo_agent_host.sendCommand("jump 0")

In [7]:
# malmo_agent_host.sendCommand("quit")

In [8]:
# for i in range(10):
#     state = malmo_agent_host.getWorldState()
#     vision = np.frombuffer(state.video_frames[0].pixels, dtype=np.uint8)
#     vision = vision.reshape(VIDEO_SHAPE+(VIDEO_DEPTH,))
#     obs = loads(state.observations[0].text)
    
#     if 'LineOfSight' not in obs:
#         obs['LineOfSight'] = {
#             'hitType': 'sky',
#             'x': float('nan'),
#             'y': float('nan'),
#             'z': float('nan'),
#             'type': 'air',
#             'inRange': False,
#             'distance': float('inf')
#         }
#     los = obs['LineOfSight']
    
#     print(f"{i}: {los['hitType']}:{los['type']} @ dist {los['distance']}")
#     plt.imshow(vision)
#     plt.show()
#     plt.close()

In [9]:
CLUSTERER = "kmeans"
LOSS_METHOD = "both"

In [10]:
odin_model = ODIN(None, None, None, lr_tau=0.1, lr_theta=1e-5, lr_xi=0.1, clusters=4, fpn_dim=128)

In [11]:
import torch

In [14]:
def run_training_iteration(input_tensor, eps_coeff, clusterer_type="kmeans", loss_method="contrastive"):
    view_gen = ViewGenerator(input_tensor)
    v0, v1, v2 = view_gen(input_tensor)

    h0, z0 = odin_model.run_tau_network(v0)

    cluster_results = dict()
    total_loss = torch.Tensor([0.]).to(odin_model.device)
    
    best_loss = None
    best_loss_val = float("inf")
    for num_c in [4]:
        clusterer = odin_model.get_clusterer(h0, eps_coeff, num_c, clusterer_type="kmeans")
        cluster_ids, masks, m0, m1, m2 = odin_model.generate_masks(h0, view_gen, clusterer)
        loss, constrastive_loss, closeness_loss = odin_model.run_networks_training_mode(v1, v2, m1, m2, loss_method)
        if loss.item() < best_loss_val:
            best_loss_val = loss.item()
            best_loss = loss
            cluster_results["kmeans"] = (cluster_ids, masks, m0, m1, m2, loss, constrastive_loss, closeness_loss)
        total_loss += loss
    

    return cluster_results, total_loss

In [15]:
session_id = f"{datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M-%S')}"
session_path = Path(f"./runs/{session_id}/")
if session_path.exists():
    session_path = session_path.joinpath(f"{datetime.strftime(datetime.now(), '%Y-%m-%d_%H-%M-%S')}/")
session_path.mkdir(exist_ok=True)
losses_log_file = session_path.joinpath("losses.txt")
losses_log_file.touch()
losses = []

prev_loss = 0
changes = []

state = malmo_agent_host.getWorldState()
vision = np.frombuffer(state.video_frames[0].pixels, dtype=np.uint8)
vision = vision.reshape(VIDEO_SHAPE+(VIDEO_DEPTH,))
demo_tensor = tfunc.to_tensor(vision).to(DEVICE).unsqueeze(0)
consecutive_skips = 0
odin_model.eps_coeff = 1.0
save_steps = 100

draw_image(demo_tensor)
for j in range(100):
    malmo_agent_host.sendCommand("quit")
    sleep(1)
    malmo_agent_host.startMission(random_world_mission_spec, 
                              malmo_client_pool,
                              random_world_record_spec,
                              0,
                              "TEST ALPHA")
    state = malmo_agent_host.getWorldState()
    while len(state.video_frames) < 1:
        state = malmo_agent_host.getWorldState()
        sleep(0.1)
    malmo_agent_host.sendCommand("jump 1")
    sleep(0.1)
    malmo_agent_host.sendCommand("turn 0.05")
    sleep(0.1)
    malmo_agent_host.sendCommand("jump 0")
    for k in range(25):    
        i=j*25+k
        odin_model.f_tau.train()
        odin_model.g_tau.train()
        odin_model.f_theta.train()
        odin_model.g_theta.train()
        odin_model.q_theta.train()
        odin_model.f_xi.train()
        odin_model.g_xi.train()

        odin_model.f_theta_optim.zero_grad()
        odin_model.g_theta_optim.zero_grad()
        odin_model.q_theta_optim.zero_grad()
        
        state = malmo_agent_host.getWorldState()
        if not state.is_mission_running:
            break
        
        vision = np.frombuffer(state.video_frames[0].pixels, dtype=np.uint8)
        vision = vision.reshape(VIDEO_SHAPE+(VIDEO_DEPTH,))
        input_tensor = tfunc.to_tensor(vision).to(DEVICE).unsqueeze(0)
        try:
            clusters_data, loss = run_training_iteration(input_tensor, odin_model.eps_coeff, CLUSTERER, LOSS_METHOD)
            if "optics" in clusters_data:
                if len(clusters_data["optics"][0]) == 0:
                    # There aren't any clusters, increase epsilon to make it easier for clusters to form
                    consecutive_skips += 1
                    odin_model.eps_coeff += 0.1 * consecutive_skips
                    continue
                if len(clusters_data["optics"][0]) == 1:
                    # There's only one cluster, decrease epsilon to make it easier for clusters to differentiate
                    odin_model.eps_coeff -= 0.1
            else:
                clusters_data["optics"] = [[]]
            if 'kmeans' not in clusters_data:
                clusters_data["kmeans"] = [[]]
            consecutive_skips = 0
            loss_val = loss.item()
            losses.append(loss_val)
            past_10_variance = np.var(losses[-10:])
            
            loss_change = loss_val - prev_loss
            changes.append(loss_change)
            if i > 25:
                changes.pop(0)
            moving_mean_change = np.mean(changes)
            moving_med_change = np.median(changes)
            prev_loss = loss_val

            losses_log_file.open('a').write(f"{i},{loss_val}\n")
            print(f"Iter {i} clust={len(clusters_data['kmeans'][0])}: loss={loss_val:+.10f} delta={loss_change:+.10f} mean={moving_mean_change:+.10f} med={moving_med_change:+.10f} var={past_10_variance:.13f}")
    
        except Exception as e:
            plt.plot(losses)
            plt.savefig(str(session_path.joinpath("loss_graph.png")))
            save_path = session_path.joinpath(f"{i}/")
            save_path.mkdir()
            odin_model.inference(demo_tensor, save_path)
            odin_model.save_parameters(save_path)
            raise e
        except KeyboardInterrupt as e:
            plt.plot(losses)
            plt.savefig(str(session_path.joinpath("loss_graph.png")))
            save_path = session_path.joinpath(f"{i}/")
            save_path.mkdir()
            odin_model.inference(demo_tensor, save_path)
            odin_model.save_parameters(save_path)
            raise e

        if i % save_steps == 0:
            save_path = session_path.joinpath(f"{i}/")
            save_path.mkdir(exist_ok=True)

            if i == 0:
                grad_graph = torchviz.make_dot(loss, dict(odin_model.q_theta.named_parameters()))
                grad_graph.render(save_path.joinpath("grad_graph"))

            odin_model.inference(demo_tensor, save_path)
            odin_model.save_parameters(save_path)

    #     if i >= iterations: break
#         if i > 100 and past_10_variance < 1e-7: break

plt.plot(losses)
plt.savefig(str(session_path.joinpath("loss_graph.png")))
plt.show()

Iter 0 clust=4: loss=+10.6829605103 delta=+10.6829605103 mean=+10.6829605103 med=+10.6829605103 var=0.0000000000000
torch.Size([4, 3, 128, 128])
Iter 1 clust=4: loss=+10.4937839508 delta=-0.1891765594 mean=+5.2468919754 med=+5.2468919754 var=0.0089469426612
Iter 2 clust=4: loss=+10.0850067139 delta=-0.4087772369 mean=+3.3616689046 med=-0.1891765594 var=0.0622705936330
Iter 3 clust=4: loss=+10.7754821777 delta=+0.6904754639 mean=+2.6938705444 med=+0.2506494522 var=0.0703191161817
Iter 4 clust=4: loss=+10.9202747345 delta=+0.1447925568 mean=+2.1840549469 med=+0.1447925568 var=0.0832782335716
Iter 5 clust=4: loss=+10.7751512527 delta=-0.1451234818 mean=+1.7958585421 med=-0.0001654625 var=0.0740828597196
Iter 6 clust=4: loss=+10.4164609909 delta=-0.3586902618 mean=+1.4880658558 med=-0.1451234818 var=0.0686781413546
Iter 7 clust=4: loss=+10.6233739853 delta=+0.2069129944 mean=+1.3279217482 med=-0.0001654625 var=0.0601960728056
Iter 8 clust=4: loss=+10.6549682617 delta=+0.0315942764 mean=+1.

Iter 72 clust=4: loss=+11.2924175262 delta=-0.0013732910 mean=+0.0489923037 med=+0.1279311180 var=0.3249171058635
Iter 73 clust=4: loss=+11.9978513718 delta=+0.7054338455 mean=+0.0713968644 med=+0.2320337296 var=0.2790793464239
Iter 74 clust=4: loss=+10.9919509888 delta=-1.0059003830 mean=+0.0275950432 med=+0.1962556839 var=0.2068483985307
Iter 75 clust=4: loss=+9.7730865479 delta=-1.2188644409 mean=-0.0101183378 med=+0.1962556839 var=0.4803042825136
Iter 76 clust=4: loss=+8.7195968628 delta=-1.0534896851 mean=-0.0648654424 med=+0.0388970375 var=1.0809744822309
Iter 77 clust=4: loss=+10.0561428070 delta=+1.3365459442 mean=-0.0158210168 med=+0.1737618446 var=1.0642138553494
Iter 78 clust=4: loss=+9.1441688538 delta=-0.9119739532 mean=-0.0343239858 med=+0.1737618446 var=1.0103954288808
Iter 79 clust=4: loss=+9.8957071304 delta=+0.7515382767 mean=-0.0501725857 med=+0.1737618446 var=0.9980217408124
Iter 80 clust=4: loss=+9.8304424286 delta=-0.0652647018 mean=-0.1092138290 med=+0.0075149536

Iter 144 clust=4: loss=+10.4375267029 delta=+0.3868675232 mean=-0.0075156138 med=+0.0942039490 var=0.2197544150137
Iter 145 clust=4: loss=+10.4637565613 delta=+0.0262298584 mean=-0.0115400828 med=+0.0458259583 var=0.1428463707955
Iter 146 clust=4: loss=+10.3045368195 delta=-0.1592197418 mean=+0.0211410522 med=+0.0458259583 var=0.1428144366051
Iter 147 clust=4: loss=+10.6480522156 delta=+0.3435153961 mean=-0.0018216647 med=+0.0458259583 var=0.1521070877254
Iter 148 clust=4: loss=+10.0039033890 delta=-0.6441488266 mean=-0.0408985798 med=-0.0469079018 var=0.0832249606670
Iter 149 clust=4: loss=+10.8619937897 delta=+0.8580904007 mean=+0.0190457564 med=+0.0458259583 var=0.0954494231224
Iter 150 clust=4: loss=+10.1356868744 delta=-0.7263069153 mean=-0.0401562911 med=-0.0469079018 var=0.0983555708525
Iter 151 clust=4: loss=+9.7437114716 delta=-0.3919754028 mean=-0.0140240009 med=-0.0469079018 var=0.1207132294680
Iter 152 clust=4: loss=+10.2811479568 delta=+0.5374364853 mean=-0.0161030476 med=

Iter 216 clust=4: loss=+12.0994014740 delta=+1.1538610458 mean=+0.0781546373 med=+0.1208796501 var=0.7600831982896
Iter 217 clust=4: loss=+10.2352619171 delta=-1.8641395569 mean=+0.0092151348 med=+0.1208796501 var=0.4592383759280
Iter 218 clust=4: loss=+11.3825578690 delta=+1.1472959518 mean=+0.0079864722 med=+0.1208796501 var=0.4416761450104
Iter 219 clust=4: loss=+11.1774845123 delta=-0.2050733566 mean=+0.0338597664 med=+0.1208796501 var=0.3972806742666
Iter 220 clust=4: loss=+11.2477111816 delta=+0.0702266693 mean=+0.0371260276 med=+0.1449627876 var=0.2925316578663
Iter 221 clust=4: loss=+10.2847385406 delta=-0.9629726410 mean=-0.0086477353 med=+0.0461435318 var=0.3773322166329
Iter 222 clust=4: loss=+10.6752033234 delta=+0.3904647827 mean=-0.0166253310 med=+0.0461435318 var=0.3886332433761
Iter 223 clust=4: loss=+9.0698490143 delta=-1.6053543091 mean=-0.0465301367 med=+0.0461435318 var=0.7512727273052
Iter 224 clust=4: loss=+9.2001113892 delta=+0.1302623749 mean=-0.0333856803 med=+

Iter 288 clust=4: loss=+9.3132266998 delta=-1.3386440277 mean=-0.0424208274 med=-0.1068196297 var=0.5340561667941
Iter 289 clust=4: loss=+9.9389858246 delta=+0.6257591248 mean=+0.0209668233 med=+0.0597667694 var=0.5549856922179
Iter 290 clust=4: loss=+10.2571411133 delta=+0.3181552887 mean=+0.0150446158 med=+0.0597667694 var=0.4555907008395
Iter 291 clust=4: loss=+9.9142322540 delta=-0.3429088593 mean=-0.0204636500 med=-0.1068196297 var=0.4494443117550
Iter 292 clust=4: loss=+9.2734966278 delta=-0.6407356262 mean=-0.0366423680 med=-0.1068196297 var=0.4906416627466
Iter 293 clust=4: loss=+9.8954439163 delta=+0.6219472885 mean=-0.0306838109 med=-0.1068196297 var=0.3723997029558
Iter 294 clust=4: loss=+11.4246110916 delta=+1.5291671753 mean=+0.0404639978 med=+0.0597667694 var=0.5187295085383
Iter 295 clust=4: loss=+9.9490537643 delta=-1.4755573273 mean=-0.0354453234 med=-0.1068196297 var=0.4012774333015
Iter 296 clust=4: loss=+10.7001161575 delta=+0.7510623932 mean=+0.0364770156 med=+0.05

KeyboardInterrupt: 

In [None]:
state = malmo_agent_host.getWorldState()
vision = np.frombuffer(state.video_frames[0].pixels, dtype=np.uint8)
vision = vision.reshape(VIDEO_SHAPE+(VIDEO_DEPTH,))
input_tensor = tfunc.to_tensor(vision).to(DEVICE).unsqueeze(0)

In [None]:
odin_model.f_tau.train()
odin_model.g_tau.train()
odin_model.f_theta.train()
odin_model.g_theta.train()
odin_model.q_theta.train()
odin_model.f_xi.train()
odin_model.g_xi.train()

odin_model.f_theta_optim.zero_grad()
odin_model.g_theta_optim.zero_grad()
odin_model.q_theta_optim.zero_grad()

In [None]:
view_gen = ViewGenerator(input_tensor)

In [None]:
v0, v1, v2 = view_gen(input_tensor)

In [None]:
h0, z0 = odin_model.run_tau_network(v0)

In [None]:
clusterer = odin_model.get_clusterer(h0, 1.0, 16, clusterer_type="kmeans")
cluster_ids, masks, m0, m1, m2 = odin_model.generate_masks(h0, view_gen, clusterer)

In [None]:
m1.shape

In [None]:
(h1_theta, h2_theta), (masked_h1_theta, masked_h2_theta), (hk1_theta, hk2_theta), (zk1_theta, zk2_theta), (pk1_theta, pk2_theta) = odin_model.run_theta_network(v1, v2, m1, m2)

In [None]:
(h1_xi, h2_xi), (masked_h1_xi, masked_h2_xi), (hk1_xi, hk2_xi), (zk1_xi, zk2_xi) = odin_model.run_xi_network(v1, v2, m1, m2)

In [None]:
contrastive_loss = odin_model.total_contrastive_loss(pk1_theta, pk2_theta, zk1_xi, zk2_xi, 0.1)
# contrastive_loss =(pk1_theta-zk2_xi).abs().sum()

In [None]:
contrastive_loss.item()

In [None]:
initial_q_theta_sd = dict([(l,p.detach().cpu().numpy()) for l,p in odin_model.q_theta.state_dict().items()])
initial_g_theta_sd = dict([(l,p.detach().cpu().numpy()) for l,p in odin_model.g_theta.state_dict().items()])
initial_f_theta_sd = dict([(l,p.detach().cpu().numpy()) for l,p in odin_model.f_theta.state_dict().items()])

In [None]:
print(initial_f_theta_sd)

In [None]:
loss_backwards = contrastive_loss.backward(retain_graph=True)
odin_model.q_theta_optim.step()
odin_model.g_theta_optim.step()
odin_model.f_theta_optim.step()

In [None]:
print("q")
for l, p in dict([(l,p.detach().cpu().numpy()) for l,p in odin_model.q_theta.state_dict().items()]).items():
    initial_val = initial_q_theta_sd[l]
    diff = p-initial_val
    if diff.sum() != 0:
        print(l)
#         print(diff)
#         print("=========================================================")

print("g")
for l, p in dict([(l,p.detach().cpu().numpy()) for l,p in odin_model.g_theta.state_dict().items()]).items():
    initial_val = initial_g_theta_sd[l]
    diff = p-initial_val
    if diff.sum() != 0:
        print(l)
#         print(diff)
#         print("=========================================================")

print("f")
for l, p in dict([(l,p.detach().cpu().numpy()) for l,p in odin_model.f_theta.state_dict().items()]).items():
    initial_val = initial_f_theta_sd[l]
    diff = p-initial_val
    if diff.sum() != 0:
        print(l)
#         print(diff)
#         print("=========================================================")

In [None]:
dict([(l,p.detach().cpu().numpy()) for l,p in odin_model.q_theta.state_dict()]).items()

In [None]:
list(odin_model.f_xi.parameters())

In [None]:
initial_f_theta_params[0].detach().cpu().numpy()