In [None]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
from typing import Tuple
import torch
from torch import nn


In [None]:
import os, sys
root_path = os.path.realpath('../')
sys.path.append(root_path)
from ego_allo_rnns.utils.utils import front_frame, input_frame
from ego_allo_rnns.data.EgoVsAllo import make_datasets
from ego_allo_rnns.models.rnns import RNN

In [None]:
# Experiment configuration: coordinate system, plus the reference frame used
# for the network inputs and for the labels.
coordinate_type = "Cartesian"
input_type, label_type = "SC", "WC"
title = (
    f"Input type: {input_type}, Label type: {label_type}, "
    f"{coordinate_type} coordinates"
)


### 1. Understand `utils.front_frame`

In [None]:
def show_frame(frames: np.ndarray, coords: Tuple[np.ndarray, np.ndarray], idx: int):
    """Display frame ``idx`` with its start and target pokes circled.

    Args:
        frames: stack of 2-D frames, indexed as ``frames[idx, :, :]``.
        coords: ``(start_coords, target_coords)``. Each array holds one poke
            per column: ``[0, idx]`` is the row (plotted as y) and ``[1, idx]``
            is the column (plotted as x).
        idx: which frame / poke pair to show.
    """
    start_coords, target_coords = coords
    plt.imshow(frames[idx, :, :])
    # Read both x and y from ``coords`` itself. Do not read them from the
    # notebook-level ``start_poke_coordinate`` / ``target_poke_coordinate``
    # globals, so the rings always match the coordinates passed in.
    for pokes, colour in ((start_coords, "lightgreen"), (target_coords, "yellow")):
        plt.scatter(pokes[1, idx], pokes[0, idx], s=400, marker="o",
                    edgecolor=colour, facecolor="none", alpha=1)
    

In [None]:
# Peek at what the random poke generator returns for a small request.
# NOTE(review): this imports from a top-level `utils` module, whereas the setup
# cell imports from `ego_allo_rnns.utils.utils`; confirm `utils` resolves on a
# fresh kernel (it depends on the working directory / sys.path).
from utils import random_poke_generator

samples = random_poke_generator(5, 3)
samples

In [None]:
# Generate 5000 frames (random_seed=20) together with the start and target
# poke coordinates that go with them.
frames, start_poke_coordinate, target_poke_coordinate = front_frame(
    frame_amount=5000,
    random_seed=20,
)

In [None]:
# Frame 0 with its start and target pokes circled.
idx = 0
poke_coords = (start_poke_coordinate, target_poke_coordinate)
show_frame(frames, poke_coords, idx)


In [None]:
# Which distinct pixel values occur in frame `idx`?
frame_idx_values = np.unique(frames[idx, :, :])
frame_idx_values

In [None]:
# Frame 2 with its start and target pokes circled.
idx = 2
show_frame(frames, (start_poke_coordinate, target_poke_coordinate), idx=idx)

### 2. Understand `utils.input_frame`

In [None]:
# Network input for the configured `input_type`, with the start (light green)
# and target (yellow) pokes circled.
idx = 0
x_train = input_frame(frames, input_type, start_poke_coordinate)
plt.imshow(x_train[idx, :, :])
# NOTE(review): the +16 row / +30 column shifts presumably map front_frame
# coordinates onto input_frame's output grid; confirm against input_frame.
row_shift, col_shift = 16, 30
ring = dict(s=400, marker='o', facecolor="none", alpha=1)
plt.scatter(start_poke_coordinate[1, idx] + col_shift,
            start_poke_coordinate[0, idx] + row_shift,
            edgecolor='lightgreen', **ring)
plt.scatter(target_poke_coordinate[1, idx] + col_shift,
            target_poke_coordinate[0, idx] + row_shift,
            edgecolor='yellow', **ring)

In [None]:
f, axs = plt.subplots(3,3,figsize=(2,2),dpi=300)
x_train = input_frame(frames, output_type="WC", start_poke=start_poke_coordinate)
axs = axs.ravel()
for ax,img in zip(axs,x_train):
    ax.imshow(img)
    ax.set_axis_off()
f.suptitle("World-centered starting location", fontsize=6)


In [None]:
f, axs = plt.subplots(3,3,figsize=(2,2),dpi=300)
x_train = input_frame(frames, output_type="SC", start_poke=start_poke_coordinate)
axs = axs.ravel()
for ax,img in zip(axs,x_train):
    ax.imshow(img)
    ax.set_axis_off()
f.suptitle("Self-centered starting location", fontsize=6)

### 3. Understand labels (number of distinct start and target locations)

In [None]:
unique_locs = {}
for xy in start_poke_coordinate.T:
    if not str(xy) in unique_locs:
        unique_locs[str(xy)] = 1
print(len(unique_locs.keys()))



In [None]:
unique_locs = {}
for xy in target_poke_coordinate.T:
    if not str(xy) in unique_locs:
        unique_locs[str(xy)] = 1
print(len(unique_locs.keys()))


## Sanity check: inspect the dataset and overfit a small training batch

In [None]:
# Build the train/test datasets used in the sanity checks below.
# NOTE(review): size_ds=40 appears to match the 40x40 reshape used when
# plotting inputs below; confirm what size_ds controls in make_datasets.
data_train, data_test = make_datasets(size_ds=40)

In [None]:
# Item 4 of data_train[0]: one panel per index along its next axis (the first
# 11). Each 1600-vector is reshaped to 40x40 and shown on a shared [0, 1]
# colour scale; the 12th panel is left blank.
# NOTE(review): presumably data_train[0] is (item, time step, pixel); confirm.
f, axs = plt.subplots(3, 4, figsize=(10, 10))
axs = axs.flatten()
example = data_train[0][4]
*step_axes, last_ax = axs
for step, ax in enumerate(step_axes):
    ax.imshow(example[step, :].reshape(40, 40), vmin=0, vmax=1)
    ax.set_axis_off()
last_ax.set_axis_off()

In [None]:
# try to overfit on training data:
# Prefer the first GPU but fall back to CPU, so the notebook also runs on
# machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Only the first 10 training items: a batch small enough to memorise.
x_train = torch.tensor(data_train[0][0:10, :], dtype=torch.float, device=device)
y_train = torch.tensor(data_train[1][0:10, :], dtype=torch.float, device=device)


In [None]:
def _detached_cpu(states):
    """Return ``states`` detached from autograd and moved to CPU.

    Tensors are detached and moved to CPU. Lists and tuples are rebuilt
    element-wise, so a container the model reuses between forward passes is
    not aliased across epochs. Anything else is returned unchanged.
    """
    if torch.is_tensor(states):
        return states.detach().cpu()
    if isinstance(states, list):
        return [_detached_cpu(s) for s in states]
    if isinstance(states, tuple):
        return tuple(_detached_cpu(s) for s in states)
    return states


torch.manual_seed(0)  # reproducible weight initialisation
n_epochs = 3000
# Put the model on whichever device the training inputs live on.
rnn = RNN(input_size=1600, hidden_size=5).to(x_train.device)
optimiser = torch.optim.Adam(rnn.parameters(), lr=1e-2)
criterion = nn.MSELoss()
mses = []        # training loss per epoch
hidden_act = []  # snapshot of rnn.hidden_states after each forward pass
for _ in range(n_epochs):
    y_ = rnn(x_train)
    loss = criterion(y_, y_train)
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    mses.append(loss.item())
    # Detach before storing. If rnn.hidden_states carries autograd history,
    # keeping it for every epoch would hold device memory and graph nodes alive.
    hidden_act.append(_detached_cpu(rnn.hidden_states))

In [None]:
# Training loss per epoch. It should approach zero if the RNN can memorise the batch.
fig, ax = plt.subplots()
ax.plot(mses)
ax.set(xlabel="epoch", ylabel="MSE loss", title="Training loss (overfitting check)");

In [None]:
# Predicted vs. true labels on the overfitting batch, using the final weights.
# y_ from the training loop was computed before the last optimiser step.
# Points on the dashed y = x line are fitted exactly.
with torch.no_grad():
    y_pred = rnn(x_train)
true_vals = y_train.detach().cpu().numpy().ravel()
pred_vals = y_pred.detach().cpu().numpy().ravel()
fig, ax = plt.subplots()
ax.plot(true_vals, pred_vals, 'r*')
lims = [min(true_vals.min(), pred_vals.min()), max(true_vals.max(), pred_vals.max())]
ax.plot(lims, lims, 'k--', linewidth=1)
ax.set(xlabel="true label", ylabel="predicted label", title="RNN fit on training batch");