In [None]:
# import os
# os.chdir('./src/terrain_mlp/fft_new/training')

In [None]:
import numpy as np
import torch as th
from tqdm import tqdm,trange    
from torch import Tensor
from torch.utils.data import Dataset, DataLoader, random_split

import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from models.flow_model_v7 import Flow
import open3d as o3d


In [None]:
# Run tag: reused below for the TensorBoard scalar name, the checkpoint
# file names and the exported .npz file name.
test = 'v7_500'

# Select the 'high' float32 matmul precision mode (lets PyTorch use faster,
# reduced-precision kernels for float32 matrix multiplications).
th.set_float32_matmul_precision('high')

# TensorBoard event files for this notebook are written under ./logs.
writer = SummaryWriter("./logs")

In [None]:
# %matplotlib widget

In [None]:
# Training configuration shared by the cells below.
BATCH_SIZE = 500       # mini-batch size handed to both DataLoaders
LEARNING_RATE = 1e-3   # learning rate handed to the optimizer
SEED = 0               # seed for the torch CPU / CUDA random number generators
DEVICE = 'cuda'        # device that batches and normalisation stats are moved to
NUM_EPOCH = 2000       # number of passes over the training loader

In [None]:
# Seed the torch CPU generator and the current CUDA device generator with the
# same value so that noise sampling and weight initialisation are repeatable.
for seed_fn in (th.manual_seed, th.cuda.manual_seed):
    seed_fn(seed=SEED)

In [None]:
class TrajDataset(Dataset):
    """Map-style dataset over parallel, index-aligned arrays.

    Every constructor argument is an indexable sequence whose first axis
    enumerates the samples; the dataset length is taken from ``x_fin_data``.

    ``__getitem__`` returns a 7-tuple
    ``(x_fin, y_fin, theta_init, theta_fin, gt_x, gt_y, index)`` in which the
    first six entries are converted to float32 tensors and ``index`` is
    returned exactly as stored.
    """

    def __init__(self, x_fin_data, y_fin_data, theta_init_data, theta_fin_data, x_data, y_data, index_data):
        # Goal description.
        self.x_fin, self.y_fin = x_fin_data, y_fin_data
        self.theta_init, self.theta_fin = theta_init_data, theta_fin_data
        # Ground-truth values.
        self.gt_x, self.gt_y = x_data, y_data
        # Per-sample index; the training loop uses it to look up extra arrays.
        self.index = index_data

    def __len__(self):
        """Number of samples (length of the ``x_fin`` array)."""
        return len(self.x_fin)

    def __getitem__(self, idx):
        """Return sample ``idx`` as six float32 tensors followed by its raw index."""
        float_sources = (
            self.x_fin,
            self.y_fin,
            self.theta_init,
            self.theta_fin,
            self.gt_x,
            self.gt_y,
        )
        as_float = tuple(th.tensor(source[idx]).float() for source in float_sources)
        return as_float + (self.index[idx],)
    

In [None]:
# ---------------------------------------------------------------------------
# Load the raw training arrays from the .npz archive.
# ---------------------------------------------------------------------------
data_set = np.load("./dataset/data_train_pcd_gd.npz")

lam_data, cov_data = data_set["lam"], data_set["cov"]
p1_data, p2_data, p3_data, p4_data = (data_set[key] for key in ("p1", "p2", "p3", "p4"))
x_fin_data, y_fin_data = data_set["x_fin"], data_set["y_fin"]
theta_init_data, theta_fin_data = data_set["theta_init"], data_set["theta_fin"]
x_data, y_data = data_set["x"], data_set["y"]
pcd_data, index_data = data_set["pcd"], data_set["index"]

# p1..p4 are joined along axis 1 into a single terrain-parameter matrix.
terrain_params = np.concatenate([p1_data, p2_data, p3_data, p4_data], axis=1)


def _mean_std_on_device(values):
    """Return (mean, std) over all elements of ``values`` as tensors on DEVICE."""
    return th.tensor(values.mean()).to(DEVICE), th.tensor(values.std()).to(DEVICE)


# Scalar normalisation statistics. They are computed over the full arrays,
# i.e. before the train / test split further down.
theta_init_mean, theta_init_std = _mean_std_on_device(theta_init_data)
theta_fin_mean, theta_fin_std = _mean_std_on_device(theta_fin_data)
x_fin_mean, x_fin_std = _mean_std_on_device(x_fin_data)
y_fin_mean, y_fin_std = _mean_std_on_device(y_fin_data)
lam_mean, lam_std = _mean_std_on_device(lam_data)
terrain_params_mean, terrain_params_std = _mean_std_on_device(terrain_params)
cov_mean, cov_std = _mean_std_on_device(cov_data)

# Optional point-cloud down-sampling (disabled):
# N = 15000
# down_pcd = np.zeros((pcd_data.shape[0], N, 3), dtype=np.float32)
# for i in range(pcd_data.shape[0]):
#     current_pcd = pcd_data[i,:,:]
#     indices = np.random.choice(len(current_pcd), N, replace=False)
#     down_pcd[i,:,:] = current_pcd[indices]
# pcd_data = down_pcd.transpose(0,2,1)

# Swap the last two axes of the point-cloud array: (a, b, c) -> (a, c, b).
pcd_data = np.transpose(pcd_data, axes=(0, 2, 1))

# ---------------------------------------------------------------------------
# Dataset, 90 / 10 split and loaders.
# ---------------------------------------------------------------------------
dataset = TrajDataset(
    x_fin_data=x_fin_data,
    y_fin_data=y_fin_data,
    theta_init_data=theta_init_data,
    theta_fin_data=theta_fin_data,
    x_data=x_data,
    y_data=y_data,
    index_data=index_data,
)
print(len(dataset))

train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size

# A dedicated, fixed-seed generator keeps the split identical across runs.
generator = th.Generator().manual_seed(0)
train_dataset, test_dataset = random_split(
    dataset, [train_size, test_size], generator=generator
)

# Both loaders shuffle and keep the final (possibly shorter) batch.
loader_options = {"batch_size": BATCH_SIZE, "shuffle": True, "drop_last": False}
test_loader = DataLoader(test_dataset, **loader_options)
train_loader = DataLoader(train_dataset, **loader_options)


In [None]:
# First positional argument of Flow.
# NOTE(review): presumably a feature / channel width - confirm in models.flow_model_v7.
out_chan = 512

# The model receives the dataset-wide normalisation statistics at construction
# time. It is moved to DEVICE (rather than a hard-coded .cuda()) so the DEVICE
# constant stays the single place that selects where training runs.
flow = Flow(
    out_chan,
    theta_init_mean, theta_init_std,
    theta_fin_mean, theta_fin_std,
    x_fin_mean, x_fin_std,
    y_fin_mean, y_fin_std,
    lam_mean, lam_std,
    terrain_params_mean, terrain_params_std,
    cov_mean, cov_std,
).to(DEVICE)

# Regression objective and optimizer used by the training loop.
loss_fn = th.nn.MSELoss()
optimizer = th.optim.AdamW(flow.parameters(), lr=LEARNING_RATE)
# scheduler = th.optim.lr_scheduler.StepLR(optimizer, step_size = 1000, gamma = 0.1)

# Resume from an earlier run (disabled):
# flow.load_state_dict(th.load(f"./weights/test_v5_4.pt"))
# optimizer.load_state_dict(th.load(f"./opts/test_v5_4.pt"))
# for param_group in optimizer.param_groups:
#     param_group['lr'] = LEARNING_RATE

# Put the model in training mode and build the compiled wrapper used for the
# forward pass. The wrapper shares parameters with `flow`, so checkpoints are
# still taken from `flow` itself.
flow.train()
c_flow = th.compile(flow)

In [None]:
# Flow-matching training loop.
#
# For every batch a noise sample x_0 and a time t in [0, 1) are drawn, the
# point x_t on the straight line between x_0 and the ground truth x_1 is
# formed, and the model is regressed onto the constant velocity x_1 - x_0.
#
# The per-epoch *mean* loss is what gets printed, logged to TensorBoard and
# used to decide whether to overwrite the "lowest" checkpoint. (Using only the
# final mini-batch of the epoch is noisy, especially because drop_last=False
# allows that batch to be much smaller than the others.)
avg_losses = []      # mean training loss of every finished epoch
last_loss = th.inf   # best (lowest) epoch-mean loss seen so far, as a float

# The visualisation cell switches the model to eval(); make sure a re-run of
# this cell in the same kernel trains in training mode again.
flow.train()

for epoch in trange(NUM_EPOCH):
    losses = []

    for (x_fin, y_fin, theta_init, theta_fin, gt_x, gt_y, index) in train_loader:

        # Tensors produced by the dataset.
        x_fin = x_fin.to(DEVICE)
        y_fin = y_fin.to(DEVICE)
        theta_init = theta_init.to(DEVICE)
        theta_fin = theta_fin.to(DEVICE)
        gt_x = gt_x.to(DEVICE)
        gt_y = gt_y.to(DEVICE)

        # Remaining per-sample arrays are looked up through the sample index.
        lam = th.tensor(lam_data[index]).float().to(DEVICE)
        p1 = th.tensor(p1_data[index]).float().to(DEVICE)
        p2 = th.tensor(p2_data[index]).float().to(DEVICE)
        p3 = th.tensor(p3_data[index]).float().to(DEVICE)
        p4 = th.tensor(p4_data[index]).float().to(DEVICE)
        cov = th.tensor(cov_data[index]).float().to(DEVICE)
        pcd = th.tensor(pcd_data[index]).float().to(DEVICE)

        terrain_data = [theta_init, theta_fin, x_fin, y_fin, lam, th.hstack([p1, p2, p3, p4]), cov]

        # x_1: ground truth with x and y stacked on dim 1.
        x_1 = th.stack([gt_x, gt_y], dim=1)
        # x_0: Gaussian noise of the same shape.
        x_0 = th.randn_like(x_1)
        # One time value per sample, shaped (batch, 1, 1) for broadcasting.
        t = th.rand(len(x_1), 1, 1).to(device=DEVICE)
        x_t = (1 - t) * x_0 + t * x_1
        dx_t = x_1 - x_0

        loss = loss_fn(c_flow(x_t, terrain_data, t, pcd), dx_t)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    mean_loss = float(np.mean(losses))
    avg_losses.append(mean_loss)
    # scheduler.step()

    if epoch % 50 == 0:
        print(f"Epoch: {epoch + 1}, Train Loss: {mean_loss:.3f}")
    writer.add_scalar('test_{}'.format(test), mean_loss, epoch)

    # Keep a copy of the weights / optimizer state of the best epoch so far.
    if mean_loss <= last_loss:
        th.save(flow.state_dict(), f"./weights/test_{test}_lowest.pt")
        th.save(optimizer.state_dict(), f"./opts/test_{test}_lowest.pt")
        last_loss = mean_loss

In [None]:
# Persist the final (last-epoch) model and optimizer state for this run tag.
final_weights_path = './weights/test_{}.pt'.format(test)
final_opt_path = './opts/test_{}.pt'.format(test)
th.save(flow.state_dict(), final_weights_path)
th.save(optimizer.state_dict(), final_opt_path)

# Normalisation statistics handed to the model, keyed by the name under which
# they are stored in the archive.
norm_stats = {
    "theta_init_mean": theta_init_mean,
    "theta_init_std": theta_init_std,
    "theta_fin_mean": theta_fin_mean,
    "theta_fin_std": theta_fin_std,
    "x_fin_mean": x_fin_mean,
    "x_fin_std": x_fin_std,
    "y_fin_mean": y_fin_mean,
    "y_fin_std": y_fin_std,
    "lam_mean": lam_mean,
    "lam_std": lam_std,
    "terrain_params_mean": terrain_params_mean,
    "terrain_params_std": terrain_params_std,
    "cov_mean": cov_mean,
    "cov_std": cov_std,
}

# Export the statistics (as NumPy values), the loss curve and the main
# hyper-parameters alongside the weights.
np.savez(
    "./out_data/data_out_test_{}.npz".format(test),
    **{name: stat.detach().cpu().numpy() for name, stat in norm_stats.items()},
    avg_losses=np.array(avg_losses),
    batch_size=BATCH_SIZE,
    num_epoch=NUM_EPOCH,
    learning_rate=LEARNING_RATE,
)

In [None]:
# Qualitative check: integrate the learned flow from noise to a trajectory for
# one batch and plot the first sample of that batch after every step.
flow.load_state_dict(th.load(f"./weights/test_{test}.pt"))
# flow.load_state_dict(th.load(f"./weights/test_{test}_lowest.pt"))

flow.eval()

# NOTE: the batch is drawn from the *training* loader.
x_fin, y_fin, theta_init, theta_fin, gt_x, gt_y, index = next(iter(train_loader))

# Per-sample arrays looked up through the sample index.
lam = th.tensor(lam_data[index]).float().to(DEVICE)
p1 = th.tensor(p1_data[index]).float().to(DEVICE)
p2 = th.tensor(p2_data[index]).float().to(DEVICE)
p3 = th.tensor(p3_data[index]).float().to(DEVICE)
p4 = th.tensor(p4_data[index]).float().to(DEVICE)
cov = th.tensor(cov_data[index]).float().to(DEVICE)
pcd = th.tensor(pcd_data[index]).float().to(DEVICE)
x_fin = x_fin.to(DEVICE)
y_fin = y_fin.to(DEVICE)
theta_init = theta_init.to(DEVICE)
theta_fin = theta_fin.to(DEVICE)
gt_x = gt_x.to(DEVICE)
gt_y = gt_y.to(DEVICE)

terrain_data = [theta_init, theta_fin, x_fin, y_fin, lam, th.hstack([p1, p2, p3, p4]), cov]

n_steps = 16
x_1 = th.stack([gt_x, gt_y], dim=1)
x_0 = th.randn_like(x_1)  # starting noise, same shape as the ground truth
# n_steps + 1 equally spaced times in [0, 1], shaped (n_steps + 1, 1, 1).
time_steps = th.linspace(0, 1.0, n_steps + 1, device=DEVICE).unsqueeze(1).unsqueeze(2)

# Subplot layout derived from n_steps (4 x 4 for the default of 16) so that
# changing n_steps does not overflow a hard-coded grid.
n_cols = 4
n_rows = max(1, -(-n_steps // n_cols))  # ceil(n_steps / n_cols)
batch = x_0.shape[0]

x = x_0.clone()
fig = plt.figure(figsize=(15, 15 * n_rows / n_cols))
for step in range(n_steps):
    ax = fig.add_subplot(n_rows, n_cols, step + 1)

    # Advance the whole batch from time_steps[step] to time_steps[step + 1].
    with th.inference_mode():
        t_start = time_steps[step].expand(batch, -1, -1)
        t_end = time_steps[step + 1].expand(batch, -1, -1)
        x = flow.step(x, terrain_data, t_start, t_end, pcd)

    # Current state of sample 0 (blue points), origin (green), goal (red dot)
    # and ground-truth path (red line).
    ax.scatter(x[0, 0, :].detach().cpu().numpy(), x[0, 1, :].detach().cpu().numpy())
    ax.plot(0, 0, 'og', markersize=5)
    ax.plot(x_fin[0].detach().cpu().numpy(), y_fin[0].detach().cpu().numpy(), 'or', markersize=5)
    ax.plot(gt_x[0, :].detach().cpu().numpy(), gt_y[0, :].detach().cpu().numpy(), 'r')

    ax.axis('equal')
    ax.set_title(f'timestep: {step}')
    # The first positional argument of Axes.grid is `visible`, not `axis`;
    # pass both explicitly instead of relying on the truthiness of "both".
    ax.grid(True, axis="both", linewidth=0.5)

plt.tight_layout()
plt.show()