In [None]:
import os
import torch
import torchvision
from torch import nn
from torch import optim
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torchvision.io import read_image
from PIL import Image

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Colab-only: mount Google Drive at /content/drive so the dataset under DATA_PATH
# (defined in the next cell) is readable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Dataset location on the mounted Google Drive. The pieces are joined by plain
# string concatenation, so DATASET_DIR and IMGS_DIR carry their own leading '/'.
DATA_PATH = '/content/drive/MyDrive/diploma'
DATASET_DIR = '/dataset-v3'
IMGS_DIR = '/imgs'
# -> '/content/drive/MyDrive/diploma/dataset-v3/imgs'
IMGS_PATH = f'{DATA_PATH}{DATASET_DIR}{IMGS_DIR}'

In [None]:
import json

def parse_row_v2(row):
  """Decode the JSON payload of one action-log row.

  The payload is everything after the LAST " --&#!--- " separator; when the
  separator is absent, the whole row is decoded.

  Raises:
    json.JSONDecodeError: if the payload is not valid JSON.
  """
  _, _, payload = row.rpartition(" --&#!--- ")
  return json.loads(payload)


def parse_actions_file(rows):
  """Parse every row that carries a JSON payload; other rows are skipped.

  Args:
    rows: iterable of text lines (e.g. from ``readlines()``).

  Returns:
    list of decoded payloads, in row order.
  """
  separator = " --&#!--- "
  return [parse_row_v2(line) for line in rows if separator in line]


def reformat_actions_into_dict(parsed_rows):
  """Index parsed action records by their 't' (time) field.

  Records are stored as-is (not copied). When two records share the same 't',
  the later one wins.
  """
  return {record['t']: record for record in parsed_rows}


def episode_actions_dict(episode_number):
  """Load one episode's action log and index its records by 't'.

  The log is read from ``DATA_PATH + DATASET_DIR + "/<episode_number>"``.
  """
  actions_file = f"{DATA_PATH}{DATASET_DIR}/{episode_number}"
  with open(actions_file, "r") as handle:
    lines = handle.readlines()
  return reformat_actions_into_dict(parse_actions_file(lines))


def episode_images_list(episode_number):
  """List the image file names (unsorted, as returned by ``os.listdir``) of one episode.

  NOTE(review): no '/' is inserted between IMGS_PATH and the episode number, so
  this reads e.g. '.../imgs21'; presumably the folders are named that way — confirm.
  """
  episode_dir = f"{IMGS_PATH}{episode_number}"
  return os.listdir(episode_dir)


def merge_actions_and_images(actions, images_list, images_path, signs_to_compare=11):
  """Pair each action record with a camera frame that shares its timestamp prefix.

  Action and image timestamps are compared on their first ``signs_to_compare``
  characters, converted to int.

  Args:
    actions: dict mapping an action timestamp to its action record (a dict).
      The records are NOT modified; each matched record is shallow-copied.
    images_list: image file names that start with a timestamp. When several
      images share a prefix, the first one in iteration order is used (the
      caller passes the list sorted).
    images_path: directory joined to the image file name with '/'.
    signs_to_compare: number of leading characters used for matching
      (default 11, the previously hard-coded value).

  Returns:
    dict mapping the truncated timestamp (int) to a copy of the action record
    with an extra 'img' key holding the image path. If several actions share a
    truncated timestamp, the one iterated last wins.

  Raises:
    ValueError: if a truncated timestamp is not an integer literal.
  """
  def to_short_time(t):
    return int(str(t)[:signs_to_compare])

  # First image per truncated timestamp.
  images_short = {}
  for image_name in images_list:
    images_short.setdefault(to_short_time(image_name), image_name)

  result = {}
  for action_time, action in actions.items():
    short_time = to_short_time(action_time)
    if short_time in images_short:
      # Copy so the caller's action records are not mutated as a side effect.
      matched = action.copy()
      matched['img'] = images_path + '/' + images_short[short_time]
      result[short_time] = matched
  return result


def reduce_frequency_by_step(episode_data, step):
  """Re-target every frame to the action recorded ``step`` frames later.

  Frames are ordered by key. Each output record is a shallow copy of the frame's
  record whose 't' and 'j' come from the frame ``step`` positions ahead. Frames
  closer than ``step`` to the end take them from the last frame. Every input key
  is kept (despite the name, the number of frames is not reduced).

  Args:
    episode_data: dict mapping a sortable key to a record dict that has at
      least 't' and 'j'.
    step: non-negative look-ahead, in frames.

  Returns:
    dict with the same keys, inserted in ascending key order.

  Raises:
    ValueError: if ``step`` is negative.
  """
  if step < 0:
    raise ValueError(f"step must be non-negative, got {step}")

  sorted_keys = sorted(episode_data.keys())
  last_index = len(sorted_keys) - 1
  new_episode_data = {}
  for index, key in enumerate(sorted_keys):
    # Clamp to the final frame, so episodes shorter than ``step`` frames (and
    # empty episodes) are handled instead of indexing past the start.
    target = episode_data[sorted_keys[min(index + step, last_index)]]
    new_elem = episode_data[key].copy()
    new_elem['t'] = target['t']
    new_elem['j'] = target['j']
    new_episode_data[key] = new_elem
  return new_episode_data


def with_action_horizon(reduced, action_horizon):
  """Attach an ``action_horizon``-step action chunk to every frame.

  Frames are ordered by key. A frame's new 'j' is its own 'j' followed by the 'j'
  of the next ``action_horizon - 1`` frames, concatenated flat. Where fewer frames
  remain, the frame's own 'j' is repeated as padding.
  NOTE(review): padding repeats the *current* frame's action, not the last
  available one. Near the end of an episode the chunk can therefore step back
  in time. Confirm this is intended.

  Args:
    reduced: dict mapping a sortable key to a record with 'j', 'current', 't'
      and 'img'.
    action_horizon: number of action steps per chunk (values below 2 leave 'j'
      unchanged apart from copying).

  Returns:
    dict with the same keys, inserted in ascending key order. The records
    contain only 'j', 'current', 't' and 'img'.
  """
  ordered = sorted(reduced)
  count = len(ordered)
  chunked = {}

  for position, key in enumerate(ordered):
    record = reduced[key]
    own_action = record['j']

    future_actions = []
    for offset in range(1, action_horizon):
      ahead = position + offset
      source = reduced[ordered[ahead]]['j'] if ahead < count else own_action
      future_actions += source.copy()

    chunked[key] = {
        'j': own_action.copy() + future_actions,
        'current': record['current'].copy(),
        't': record['t'],
        'img': record['img'],
    }
  return chunked

def form_data_for_episode(episode_number, reduction_step, action_horizon):
  """Build the training samples of a single episode.

  The pipeline has four steps:
  1. Load the action records.
  2. Pair them with frames by timestamp prefix.
  3. Move each frame's target ``reduction_step`` frames ahead.
  4. Attach an ``action_horizon``-step action chunk.

  Returns:
    dict mapping truncated timestamps to sample records.
  """
  episode_images_dir = IMGS_PATH + str(episode_number)
  merged = merge_actions_and_images(
      episode_actions_dict(episode_number),
      sorted(episode_images_list(episode_number)),
      episode_images_dir,
  )
  shifted = reduce_frequency_by_step(merged, reduction_step)
  return with_action_horizon(shifted, action_horizon)


def episodes_data(numbers_of_episodes, reduction_step, action_horizon):
  """Build the samples of several episodes and merge them into one dict.

  NOTE(review): keys are truncated timestamps. If two episodes ever produced the
  same key, the later episode's sample would silently replace the earlier one.
  """
  per_episode = (form_data_for_episode(episode, reduction_step, action_horizon)
                 for episode in numbers_of_episodes)
  combined = {}
  for samples in per_episode:
    combined.update(samples)
  return combined

In [None]:
class CustomImageDataset(Dataset):
    """Map-style dataset yielding ``(image, target action, current joints)``.

    Each value of ``data`` is a sample record with an 'img' file path, a 'j'
    target vector and a 'current' joint vector.
    """

    def __init__(self, data, image_transform, action_transform):
        # Freeze the record order once so indices are stable.
        self.data = list(data.values())
        self.image_transform = image_transform
        self.action_transform = action_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        # Close the file deterministically. Force 3-channel RGB because the
        # transforms and model in this notebook expect 3 channels; grayscale or
        # RGBA frames would otherwise fail in Normalize / the first Conv2d.
        with Image.open(sample['img']) as raw_image:
            image = self.image_transform(raw_image.convert('RGB'))

        action = self.action_transform(sample['j'])
        current = self.action_transform(sample['current'])
        # Use the default float dtype explicitly. Integer-only lists would
        # otherwise become int64 tensors, which nn.Linear cannot consume.
        dtype = torch.get_default_dtype()
        return image, torch.tensor(action, dtype=dtype), torch.tensor(current, dtype=dtype)

In [None]:
class DenseLayer(nn.Module):
    """Bottleneck dense layer that appends ``growth_rate`` feature maps to its input.

    A 1x1 conv widens to ``4 * growth_rate`` channels and a 3x3 conv (padding 1)
    produces ``growth_rate`` maps. These are concatenated onto the input along
    dim 1, so the output has ``in_channels + growth_rate`` channels and the same
    spatial size as the input.

    NOTE(review): the DenseNet-BC reference layer is BN-ReLU-Conv1x1-BN-ReLU-Conv3x3.
    Here there is no ReLU after either BatchNorm, and a single ReLU follows the
    3x3 conv. It is kept as-is because changing it would alter the trained
    architecture.
    """

    def __init__(self, in_channels, growth_rate):
        super().__init__()
        bottleneck_channels = 4 * growth_rate
        stages = [
            nn.BatchNorm2d(num_features=in_channels),
            nn.Conv2d(in_channels, bottleneck_channels,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(num_features=bottleneck_channels),
            nn.Conv2d(bottleneck_channels, growth_rate,
                      kernel_size=3, stride=1, padding=1, bias=False),
            nn.ReLU(),
        ]
        # Attribute name and stage order are part of the state_dict keys.
        self.simple_layer = nn.Sequential(*stages)

    def forward(self, x):
        new_features = self.simple_layer(x)
        return torch.cat([x, new_features], dim=1)

class DenseBlock(nn.Module):
    """A stack of ``number_of_dence_layers`` DenseLayers applied in sequence.

    Each DenseLayer appends ``growth_rate`` channels, so layer k receives
    ``in_channels + k * growth_rate`` channels.
    """
    def __init__(self, in_channels, number_of_dence_layers, growth_rate):
        super().__init__()
        self.number_of_dence_layers = number_of_dence_layers
        layers = []
        channels = in_channels
        for _ in range(number_of_dence_layers):
            layers.append(DenseLayer(channels, growth_rate))
            channels += growth_rate
        self.simple_block = nn.Sequential(*layers)

    def forward(self, x):
        return self.simple_block(x)

class TransitionLayer(nn.Module):
    """Transition between dense blocks.

    It applies BatchNorm, then a 1x1 conv down to ``int(in_channels * compression_factor)``
    channels, then 2x2 average pooling with stride 2 (halves H and W, rounding down).
    """
    def __init__(self, in_channels, compression_factor):
        super().__init__()
        out_channels = int(in_channels * compression_factor)
        norm = nn.BatchNorm2d(in_channels)
        squeeze = nn.Conv2d(in_channels, out_channels,
                            kernel_size=1, stride=1, padding=0, bias=False)
        downsample = nn.AvgPool2d(kernel_size=2, stride=2)
        self.transition_layer = nn.Sequential(norm, squeeze, downsample)

    def forward(self, x):
        return self.transition_layer(x)

class DenseNet(nn.Module):
    """DenseNet-style image encoder ending in a linear head.

    Args:
        in_channels: channels of the input image.
        densenet_config: number of DenseLayers per DenseBlock. Every block except
            the last is followed by a TransitionLayer.
        head_n: output features of the final Linear layer.
        growth_rate: channels appended by each DenseLayer.
        compression_factor: channel multiplier applied by each TransitionLayer.

    ``forward`` returns a ``(batch, head_n)`` tensor.
    """

    def __init__(self, in_channels, densenet_config, head_n, growth_rate, compression_factor):
        super().__init__()

        start_channels = 64

        # Stem: 7x7/stride-2 conv, BN, ReLU, then 2x2/stride-2 max-pool.
        self.prepare_block = nn.Sequential(
            nn.Conv2d(in_channels, start_channels,
                      kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(start_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Every block but the last is followed by a compressing transition.
        channels = start_channels
        body = []
        for layers_in_block in densenet_config[:-1]:
            body.append(DenseBlock(channels, layers_in_block, growth_rate))
            channels = int(channels + growth_rate * layers_in_block)
            body.append(TransitionLayer(channels, compression_factor))
            channels = int(channels * compression_factor)
        self.main_part = nn.Sequential(*body)

        # Last dense block, then BN and global average pooling to (batch, C, 1, 1).
        # NOTE(review): the reference DenseNet applies a ReLU after this BatchNorm;
        # none is applied here.
        final_dense = DenseBlock(channels, densenet_config[-1], growth_rate)
        channels = int(channels + growth_rate * densenet_config[-1])
        self.last_block = nn.Sequential(
            final_dense,
            nn.BatchNorm2d(channels),
            nn.AdaptiveAvgPool2d(1),
        )

        self.head = nn.Linear(channels, head_n)

    def forward(self, x):
        stem = self.prepare_block(x)
        pooled = self.last_block(self.main_part(stem))
        return self.head(pooled.flatten(start_dim=1))


In [None]:
class CVAE(nn.Module):
  """Conditional VAE that predicts a chunk of actions from an image and current joints.

  Encoder: flattened target action chunk -> Gaussian latent of ``hidden_size`` dims.
  Decoder: [z, image features, projected current joints] -> action chunk of
  ``joint_dim * action_horizon`` values.

  Args:
    action_horizon: number of action steps per chunk.
    joint_dim: size of one action / joint vector (default 7, the previously
      hard-coded value).
  """
  def __init__(self, action_horizon, joint_dim=7):
    super().__init__()

    self.hidden_size = 64
    self.action_horizon = action_horizon
    # 512 DenseNet image features + 512 projected joint features.
    self.condition_size = 1024
    action_size = joint_dim * self.action_horizon

    # The encoder consumes the full target chunk, which has the same width as
    # the decoder output. A fixed input width of 7 only matched action_horizon == 1.
    self.encoder = nn.Sequential(
        nn.Linear(action_size, 1024),
        nn.ReLU(),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Linear(512, 256),
        nn.ReLU(),
    )

    self.mean_layer = nn.Linear(256, self.hidden_size)
    self.logvar_layer = nn.Linear(256, self.hidden_size)

    # The attribute name (typo included) is kept because it is part of the state_dict keys.
    self.feature_exctractor = DenseNet(3, [6, 12, 24, 16], 512, 32, 0.5)
    self.joint_projector = nn.Linear(joint_dim, 512)

    # NOTE(review): the first decoder layer narrows the 1088-wide input
    # (condition + latent) to hidden_size (64) units. Confirm this bottleneck is intended.
    self.decoder = nn.Sequential(
        nn.Linear(self.condition_size + self.hidden_size, self.hidden_size),
        nn.ReLU(),
        nn.Linear(self.hidden_size, 256),
        nn.ReLU(),
        nn.Linear(256, 512),
        nn.ReLU(),
        nn.Linear(512, 1024),
        nn.ReLU(),
        nn.Linear(1024, action_size),
    )

  def reparametrize(self, mu, logvar):
    """Sample z = mu + exp(0.5 * logvar) * eps, with eps ~ N(0, I)."""
    std = torch.exp(0.5 * logvar)
    random_sample = torch.randn_like(std)
    return std * random_sample + mu

  def decode(self, z, image, joints):
    """Predict an action chunk from latent ``z``, an image batch and current joints.

    ``z``, the image features and the projected joints are concatenated on the
    last dimension, in that order.
    """
    image_vector = self.feature_exctractor(image)
    joint_vector = self.joint_projector(joints)
    concat_vector = torch.cat((z, image_vector, joint_vector), dim=-1)
    return self.decoder(concat_vector)

  def forward(self, x, image, joints):
    """Encode target chunk ``x``, sample z, and decode it conditioned on image and joints.

    Returns:
      (encoder hidden features (256-d), mu, logvar, decoded action chunk).
    """
    latent_representation = self.encoder(x)

    mu = self.mean_layer(latent_representation)
    logvar = self.logvar_layer(latent_representation)

    z = self.reparametrize(mu, logvar)

    return latent_representation, mu, logvar, self.decode(z, image, joints)


In [None]:
def _mean_reconstruction_loss(model, loader, criterion, device):
  """Average ``criterion(decoded, labels)`` over the batches of ``loader``, without gradients.

  The full model is run (encoder on the labels + sampled latent + decoder).
  """
  total = 0.0
  batches = 0
  with torch.no_grad():
    for images, labels, currents in loader:
      images = images.to(device)
      labels = labels.to(device)
      currents = currents.to(device)
      _, _, _, decoded = model(labels, images, currents)
      total += criterion(decoded, labels).item()
      batches += 1
  return total / batches


def train_visuomotor_policy(
    model, dataset, learning_rate, batch_size, num_of_epochs, device, shuffle,
    validation_set, loss_list, val_loss_list):
  """Train the CVAE policy with AdamW and record per-epoch losses.

  The objective is 0.75 * mean L1 reconstruction + 0.25 * KL(q(z|x) || N(0, I)).
  Every call builds a fresh AdamW optimizer, so optimizer state is not carried
  over between calls.

  NOTE(review): the KL term is summed over batch and latent dims, while the L1
  term is a mean, so their relative weight depends on batch and latent size.
  Validation also feeds the ground-truth labels through the encoder, whereas
  test_resnet decodes from z = 0, so the validation loss is likely optimistic.
  Confirm both are intended.

  Args:
    model: module called as ``model(labels, images, currents)`` that returns
      ``(encoded, mu, log_var, decoded)``.
    dataset: training dataset yielding ``(image, label, current)``.
    learning_rate: AdamW learning rate.
    batch_size: batch size for both loaders.
    num_of_epochs: number of passes over ``dataset``.
    device: torch device the model and batches are moved to.
    shuffle: whether the training loader shuffles.
    validation_set: validation dataset with the same item layout.
    loss_list: appended in place with the mean train L1 loss per epoch.
    val_loss_list: appended in place with the mean validation L1 loss per epoch.

  Returns:
    The trained model (on ``device``, left in train mode).

  Raises:
    ValueError: if ``dataset`` or ``validation_set`` is empty.
  """
  if len(dataset) == 0:
    raise ValueError("training dataset is empty")
  if len(validation_set) == 0:
    raise ValueError("validation dataset is empty")

  RECONSTRUCTION_COEF = 0.75
  KLD_COEF = 0.25
  optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
  criterion = nn.L1Loss(reduction="mean")
  dataloader = torch.utils.data.DataLoader(
      dataset=dataset, batch_size=batch_size, shuffle=shuffle)
  # Sample order does not matter for an average, so keep validation deterministic.
  validation_loader = torch.utils.data.DataLoader(
      dataset=validation_set, batch_size=batch_size, shuffle=False)

  model = model.to(device)
  model.train()

  for epoch in range(num_of_epochs):
    reconstruction_loss_sum = 0.0
    batches_count = 0
    for images, labels, currents in dataloader:
      batches_count += 1
      images = images.to(device)
      labels = labels.to(device)
      currents = currents.to(device)

      _, mu, log_var, decoded = model(labels, images, currents)

      reconstruction_loss = criterion(decoded, labels)
      kld_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
      loss = RECONSTRUCTION_COEF * reconstruction_loss + KLD_COEF * kld_loss

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      reconstruction_loss_sum += reconstruction_loss.item()

    epoch_loss = reconstruction_loss_sum / batches_count
    loss_list.append(epoch_loss)

    model.eval()
    val_loss = _mean_reconstruction_loss(model, validation_loader, criterion, device)
    val_loss_list.append(val_loss)
    model.train()

    print(f'epoch \t{epoch}\t -- train loss: {epoch_loss} -- val loss: {val_loss}')

  return model

In [None]:
def _make_image_transform():
    """Resize to 224x224, convert to a float tensor, and normalize each channel
    with the ImageNet RGB mean/std."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


# Training pipeline. The Grayscale and ColorJitter augmentations are currently
# disabled, so it is identical to the evaluation pipeline.
image_transformator = _make_image_transform()
# Evaluation pipeline (no augmentation).
test_image_transformator = _make_image_transform()

def action_transformator(action):
  """Identity transform for action and joint vectors: returns ``action`` itself.

  A min-max rescaling of each value from [-6.284, 6.284] to [0, 1] was
  prototyped earlier and is currently disabled.
  """
  unchanged = action
  return unchanged


In [None]:
# ACTION_HORIZON: number of action vectors concatenated into each target
#   (action_horizon of with_action_horizon).
# FREQUENCY_REDUCTION: look-ahead, in matched frames, used to pick each frame's
#   target (reduction_step of reduce_frequency_by_step).
ACTION_HORIZON, FREQUENCY_REDUCTION = 1, 25

In [None]:
# Episodes 28 and 29 are held out for validation; the remaining listed episodes train.
train_data = episodes_data([21, 23, 24, 25, 26, 27, 30, 31, 32], FREQUENCY_REDUCTION, ACTION_HORIZON)
val_data = episodes_data([28, 29], FREQUENCY_REDUCTION, ACTION_HORIZON)

In [None]:
# Wrap the sample dicts as torch Datasets. Training uses image_transformator and
# validation uses the augmentation-free test_image_transformator.
train_dataset = CustomImageDataset(train_data, image_transformator, action_transformator)
val_dataset = CustomImageDataset(val_data, test_image_transformator, action_transformator)

In [None]:
# Fresh (untrained) CVAE whose decoder predicts ACTION_HORIZON action steps per sample.
model = CVAE(ACTION_HORIZON)

In [None]:
# Use the GPU when available; the bare `device` line displays the choice.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Per-epoch loss histories. The training calls below append to them in place,
# so they accumulate across all training stages.
train_loss_list = []
val_loss_list = []

In [None]:
# Stage 1: 50 epochs at lr=1e-5. Each call builds a fresh AdamW optimizer and
# appends to train_loss_list / val_loss_list.
model = train_visuomotor_policy(
    model,
    train_dataset,
    learning_rate=1e-5,
    batch_size=8,
    num_of_epochs=50,
    device=device,
    shuffle=True,
    validation_set=val_dataset,
    loss_list=train_loss_list,
    val_loss_list=val_loss_list,
)

In [None]:
# Stage 2: 10 more epochs at lr=1e-7 (new optimizer; state from stage 1 is discarded).
model = train_visuomotor_policy(
    model,
    train_dataset,
    learning_rate=1e-7,
    batch_size=8,
    num_of_epochs=10,
    device=device,
    shuffle=True,
    validation_set=val_dataset,
    loss_list=train_loss_list,
    val_loss_list=val_loss_list,
)

In [None]:
# Stage 3: 10 more epochs at lr=1e-9.
# NOTE(review): updates at this learning rate are negligible. Confirm this stage is worth its cost.
model = train_visuomotor_policy(
    model,
    train_dataset,
    learning_rate=1e-9,
    batch_size=8,
    num_of_epochs=10,
    device=device,
    shuffle=True,
    validation_set=val_dataset,
    loss_list=train_loss_list,
    val_loss_list=val_loss_list,
)

In [None]:
# File name (with its leading '/') of the saved model under DATA_PATH + '/models'.
TRAINED_MODEL = '/comparation_cvae_v1_1'

In [None]:
# Pickles the whole module (not just its state_dict). Loading it later therefore
# needs the class definitions above (CVAE, DenseNet, ...) under the same names.
# NOTE(review): assumes the '/models' directory already exists on Drive.
torch.save(model, DATA_PATH + '/models' + TRAINED_MODEL)

**Model testing**

In [None]:
def test_resnet(model, dataset, device):
  criterion = nn.L1Loss(reduction="mean")
  dataloader = torch.utils.data.DataLoader(
      dataset=dataset, batch_size=1, shuffle=False)

  diffs = []

  model = model.to(device)
  model.eval()
  loss_sum = 0
  batches_num = 0
  for data in dataloader:
    batches_num += 1
    images, labels, currents = data
    images = images.to(device)
    labels = labels.to(device)
    currents = currents.to(device)

    z = torch.zeros(1, model.hidden_size).to(device)
    prediction = model.decode(z, images, currents)

    loss = criterion(prediction, labels)

    diff = torch.abs(prediction - labels)
    diffs.append(diff.max().item())

    loss_sum += loss.item()
  epoch_loss = loss_sum / batches_num
  return epoch_loss, diffs

In [None]:
# Switch to evaluation mode (BatchNorm uses its running statistics). test_resnet
# also calls model.eval(), so in practice this cell mainly displays the module.
model.eval()

In [None]:
# Held-out test episodes 19 and 22, using the augmentation-free image pipeline.
test_data = episodes_data([19, 22], FREQUENCY_REDUCTION, ACTION_HORIZON)
test_d = CustomImageDataset(test_data, test_image_transformator, action_transformator)

In [None]:
model.eval()
metric, diffs = test_resnet(model, test_d, device)
# Mean L1 over the test samples, then the 5 smallest and 5 largest per-sample max errors.
ranked_diffs = sorted(diffs)
print('res= ', metric)
print(ranked_diffs[:5])
print(ranked_diffs[-5:])

In [None]:
# Same evaluation restricted to episode 22 (overwrites test_data / test_d).
test_data = episodes_data([22], FREQUENCY_REDUCTION, ACTION_HORIZON)
test_d = CustomImageDataset(test_data, test_image_transformator, action_transformator)

In [None]:
model.eval()
metric, diffs = test_resnet(model, test_d, device)
# Mean L1 over the test samples, then the 5 smallest and 5 largest per-sample max errors.
ranked_diffs = sorted(diffs)
print('res= ', metric)
print(ranked_diffs[:5])
print(ranked_diffs[-5:])

**Plots**

In [None]:
fig, ax = plt.subplots()

# Both lists hold one mean L1 reconstruction loss per epoch, accumulated over all
# training calls. The first entry is omitted. x values are the list positions,
# so the axis is not shifted by one after slicing.
ax.plot(range(1, len(train_loss_list)), train_loss_list[1:], label='train', color='maroon')
ax.plot(range(1, len(val_loss_list)), val_loss_list[1:], label='validation', color='green')

ax.set_xlabel('epoch')
ax.set_ylabel('L1 reconstruction loss')
ax.set_title('CVAE training and validation loss')

ax.legend()

plt.show()