In [None]:
import os
import torch
import torchvision
from torch import nn
from torch import optim
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from torchvision.io import read_image
from PIL import Image

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from google.colab import drive
# Mount Google Drive so that the dataset and checkpoints under /content/drive are reachable.
drive.mount('/content/drive')

In [None]:
# Root project folder on the mounted Drive.
DATA_PATH = '/content/drive/MyDrive/diploma'
# Sub-paths keep their leading '/' because full paths are built by plain string concatenation.
DATASET_DIR = '/dataset-v3'
IMGS_DIR = '/imgs'
# Prefix of the per-episode image folders: the episode number is appended directly
# (IMGS_PATH + str(episode_number)), i.e. without a path separator in between.
IMGS_PATH = DATA_PATH + DATASET_DIR + IMGS_DIR

In [None]:
import json

def parse_row_v2(row):
  """Decode the JSON payload that follows the last ' --&#!--- ' marker of a row.

  Args:
    row: One raw text line of an actions file.

  Returns:
    Whatever ``json.loads`` produces for the text after the last marker
    (the whole row is decoded when the marker is absent).
  """
  payload = row.split(" --&#!--- ")[-1]
  return json.loads(payload)


def parse_actions_file(rows):
  """Parse every row that carries the ' --&#!--- ' marker; other rows are ignored.

  Args:
    rows: Iterable of raw text lines.

  Returns:
    List of decoded payloads (one per marked row, in input order).
  """
  marker = " --&#!--- "
  return [parse_row_v2(line) for line in rows if marker in line]


def reformat_actions_into_dict(parsed_rows):
  """Index parsed rows by their 't' field.

  Rows sharing the same 't' overwrite each other: the last one wins, while the
  key keeps the position of its first occurrence.
  """
  return {entry['t']: entry for entry in parsed_rows}


def episode_actions_dict(episode_number):
  """Load the actions file of one episode and index its records by 't'.

  The file is read from DATA_PATH + DATASET_DIR + "/" + <episode_number>.
  """
  actions_file = DATA_PATH + DATASET_DIR + "/" + str(episode_number)
  with open(actions_file, "r") as handle:
    lines = list(handle)
  return reformat_actions_into_dict(parse_actions_file(lines))


def episode_images_list(episode_number):
  """Return the directory entries of the image folder of the given episode."""
  episode_dir = IMGS_PATH + str(episode_number)
  return os.listdir(episode_dir)


def merge_actions_and_images(actions, images_list, images_path):
  """Attach an image path to every action whose timestamp prefix matches an image.

  Actions and images are matched on the first 11 characters of their
  timestamp / file name, interpreted as an integer.

  Args:
    actions: Dict mapping a timestamp to an action record (a dict).
    images_list: Image file names; each must start with at least one digit
      within its first 11 characters. When several names share a prefix, the
      first one in the list is used.
    images_path: Directory prefix prepended (with '/') to the image name.

  Returns:
    Dict keyed by the integer timestamp prefix. Each value is a shallow copy of
    the action record with an extra 'img' entry. Actions without a matching
    image are dropped; actions sharing a prefix overwrite each other.

  Raises:
    ValueError: If a prefix cannot be converted to ``int``.
  """
  SIGNS_TO_COMPARE = 11

  def to_short_time(t):
    return int(str(t)[:SIGNS_TO_COMPARE])

  images_short = {}
  for image_name in images_list:
    # setdefault keeps the first image seen for each prefix.
    images_short.setdefault(to_short_time(image_name), image_name)

  result = {}
  for action_time, action in actions.items():
    short_time = to_short_time(action_time)
    if short_time in images_short:
      # Copy so that the caller's action records are not modified.
      merged = dict(action)
      merged['img'] = images_path + '/' + images_short[short_time]
      result[short_time] = merged
  return result


def reduce_frequency_by_step(episode_data, step):
  """Pair every sample with the target recorded ``step`` samples later.

  Samples are ordered by key. The returned sample for position ``i`` is a
  shallow copy of the original one whose 't' and 'j' fields are taken from the
  sample at position ``i + step``. Positions that would run past the end use
  the last sample instead, so episodes shorter than ``step`` (or empty ones)
  are handled rather than raising IndexError.

  Args:
    episode_data: Dict mapping a sortable key to a sample dict with 't' and 'j'.
    step: Non-negative look-ahead distance in samples.

  Returns:
    New dict with the same keys (in sorted order). The 'j' value is shared with
    the sample it was taken from, not copied.

  Raises:
    ValueError: If ``step`` is negative.
  """
  if step < 0:
    raise ValueError("step must be non-negative")

  sorted_keys = sorted(episode_data.keys())
  last_index = len(sorted_keys) - 1
  new_episode_data = {}
  for index, key in enumerate(sorted_keys):
    # Clamp the look-ahead so the tail of the episode targets the final sample.
    target = episode_data[sorted_keys[min(index + step, last_index)]]
    new_elem = episode_data[key].copy()
    new_elem['t'] = target['t']
    new_elem['j'] = target['j']
    new_episode_data[key] = new_elem
  return new_episode_data


def with_action_horizon(reduced, action_horizon):
  """Extend each sample's 'j' list with the 'j' lists of the following samples.

  For the sample at sorted position ``i`` the resulting 'j' is the
  concatenation of the 'j' lists at positions ``i, i+1, ..., i+action_horizon-1``.
  Positions past the end of the episode are filled with the sample's own 'j'.

  Args:
    reduced: Dict mapping a sortable key to a sample dict with the keys
      'j', 'current', 't' and 'img'.
    action_horizon: Number of 'j' lists concatenated per sample.

  Returns:
    New dict with the same keys in sorted order; 'j' and 'current' are fresh
    copies, 't' and 'img' are passed through.
  """
  ordered_keys = sorted(reduced)
  count = len(ordered_keys)
  result = {}

  for position, key in enumerate(ordered_keys):
    entry = reduced[key]
    stacked = entry['j'].copy()
    for offset in range(1, action_horizon):
      ahead = position + offset
      # Beyond the last sample, repeat this sample's own target.
      source = reduced[ordered_keys[ahead]] if ahead < count else entry
      stacked.extend(source['j'])

    result[key] = {
        'j': stacked,
        'current': entry['current'].copy(),
        't': entry['t'],
        'img': entry['img'],
    }
  return result

def form_data_for_episode(episode_number, reduction_step, action_horizon):
  """Build the sample dict of one episode.

  Steps: load the action records, list the episode's images (sorted by name),
  join both on their timestamp prefix, shift the targets by ``reduction_step``
  samples and finally stack ``action_horizon`` targets per sample.
  """
  images_dir = IMGS_PATH + str(episode_number)
  actions = episode_actions_dict(episode_number)
  image_names = sorted(episode_images_list(episode_number))
  merged = merge_actions_and_images(actions, image_names, images_dir)
  shifted = reduce_frequency_by_step(merged, reduction_step)
  return with_action_horizon(shifted, action_horizon)


def episodes_data(numbers_of_episodes, reduction_step, action_horizon):
  """Merge the sample dicts of several episodes into a single dict.

  Samples of a later episode replace earlier ones that share the same key.
  """
  combined = {}
  for episode_number in numbers_of_episodes:
    episode_samples = form_data_for_episode(
        episode_number, reduction_step, action_horizon)
    combined.update(episode_samples)
  return combined

In [None]:
class CustomImageDataset(Dataset):
    """Dataset yielding ``(image, action, current)`` triples.

    Args:
        data: Dict whose values are sample dicts with the keys 'img' (path of
            the image file), 'j' (target values) and 'current' (current state
            values). Only the values are kept, in the dict's iteration order.
        image_transform: Callable applied to the loaded PIL image.
        action_transform: Callable applied to both 'j' and 'current'.
    """

    def __init__(self, data, image_transform, action_transform):
        self.data = list(data.values())
        self.image_transform = image_transform
        self.action_transform = action_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        # The context manager closes the underlying file once the pixels are
        # read. convert('RGB') guarantees three channels, which is what a
        # 3-value mean/std normalisation and a ResNet backbone require, even
        # if a file is stored as RGBA, palette or greyscale.
        with Image.open(sample['img']) as raw_image:
            image = self.image_transform(raw_image.convert('RGB'))

        action = self.action_transform(sample['j'])
        current = self.action_transform(sample['current'])
        # An explicit float dtype keeps values that JSON decoded as integers
        # from producing int64 tensors, which linear layers and L1 loss reject.
        return (
            image,
            torch.tensor(action, dtype=torch.float32),
            torch.tensor(current, dtype=torch.float32),
        )

In [None]:
class DenceNN(nn.Module):
  """Fully connected head: 1128 -> 512 -> 256 -> 128 -> 64 -> action_horizon * 7.

  Every hidden linear layer is followed by a ReLU; the final layer is linear.
  The output holds 7 values per horizon step.
  """

  def __init__(self, action_horizon):
    super().__init__()
    widths = [1128, 512, 256, 128, 64]
    hidden = []
    # Layers are created in input-to-output order, alternating Linear / ReLU.
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
      hidden.append(nn.Linear(fan_in, fan_out))
      hidden.append(nn.ReLU())
    self.dence_layer_1 = nn.Sequential(*hidden)
    self.dence_head = nn.Sequential(
        nn.Linear(widths[-1], action_horizon * 7)
    )

  def forward(self, x):
    """Map a ``(..., 1128)`` tensor to ``(..., action_horizon * 7)``."""
    return self.dence_head(self.dence_layer_1(x))


class ResnetVisiomotorPolicy(nn.Module):
  """Policy that predicts actions from an image and the current 7-value state.

  A ResNet-34 encodes the image, a linear layer projects the 7-value state to
  128 features, and both are concatenated and fed to ``DenceNN``. With the
  stock ResNet-34 classification head (1000 outputs) the concatenation has
  1000 + 128 = 1128 features, which is the input width ``DenceNN`` expects.

  Args:
    action_horizon: Forwarded to ``DenceNN``; output width is action_horizon * 7.
    pretrained_resnet: If truthy, the ResNet-34 starts from torchvision's
      default pretrained weights; otherwise it is randomly initialised.
  """

  def __init__(self, action_horizon, pretrained_resnet):
    super().__init__()

    # torchvision expects a member of the weights enum (or None). Passing the
    # enum class itself is rejected with a TypeError.
    weights = None
    if pretrained_resnet:
      weights = torchvision.models.ResNet34_Weights.DEFAULT
    feature_extractor = torchvision.models.resnet34(weights=weights, progress=False)

    # To freeze the backbone, set ``requires_grad = False`` on the parameters
    # of ``feature_extractor`` here (not enabled).

    self.perception_network = feature_extractor
    self.policy_network = DenceNN(action_horizon)
    self.joint_space_projection = nn.Linear(7, 128)

  def forward(self, image, current_state):
    """Return the policy output for a batch of images and current states.

    Args:
      image: Batch of image tensors accepted by ResNet-34.
      current_state: Tensor whose last dimension has size 7.
    """
    visual_repr = self.perception_network(image)
    joints_repr = self.joint_space_projection(current_state)
    concat_repr = torch.cat((visual_repr, joints_repr), -1)
    return self.policy_network(concat_repr)

In [None]:
def train_visuomotor_policy(
    model, dataset, learning_rate, batch_size, num_of_epochs, device, shuffle,
    validation_set, loss_list, val_loss_list):
  """Train ``model`` with AdamW and L1 loss, validating after every epoch.

  A fresh optimizer is created on every call. Reported losses are the mean L1
  error over all samples of the epoch: batch means are weighted by batch size,
  so a smaller final batch does not skew the result. The validation loader is
  not shuffled, which makes the validation loss of a given model reproducible.

  Args:
    model: Module called as ``model(images, currents)``.
    dataset: Training dataset yielding ``(image, label, current)`` triples.
    learning_rate: AdamW learning rate.
    batch_size: Batch size of both the training and the validation loader.
    num_of_epochs: Number of passes over ``dataset``.
    device: Device the model and the batches are moved to.
    shuffle: Whether the training loader shuffles the samples.
    validation_set: Dataset with the same sample layout as ``dataset``.
    loss_list: List that receives one training loss per epoch (appended).
    val_loss_list: List that receives one validation loss per epoch (appended).

  Returns:
    The model (on ``device``, left in training mode).
  """
  optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
  criterion = nn.L1Loss(reduction="mean")
  dataloader = torch.utils.data.DataLoader(
      dataset=dataset, batch_size=batch_size, shuffle=shuffle)
  validation_loader = torch.utils.data.DataLoader(
      dataset=validation_set, batch_size=batch_size, shuffle=False)

  model = model.to(device)
  model.train()

  for epoch in range(num_of_epochs):
    loss_sum = 0.0
    samples_seen = 0
    for images, labels, currents in dataloader:
      images = images.to(device)
      labels = labels.to(device)
      currents = currents.to(device)

      prediction = model(images, currents)
      loss = criterion(prediction, labels)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Weight the batch mean by the batch size (the last batch may be smaller).
      batch_len = labels.size(0)
      loss_sum += loss.item() * batch_len
      samples_seen += batch_len

    epoch_loss = loss_sum / samples_seen if samples_seen else float("nan")
    loss_list.append(epoch_loss)

    model.eval()
    v_loss = 0.0
    val_samples_seen = 0
    with torch.no_grad():
      for v_images, v_labels, v_currents in validation_loader:
        v_images = v_images.to(device)
        v_labels = v_labels.to(device)
        v_currents = v_currents.to(device)

        v_prediction = model(v_images, v_currents)

        v_batch_len = v_labels.size(0)
        v_loss += criterion(v_prediction, v_labels).item() * v_batch_len
        val_samples_seen += v_batch_len

    # NaN (instead of a ZeroDivisionError) signals an empty validation set.
    val_loss = v_loss / val_samples_seen if val_samples_seen else float("nan")
    val_loss_list.append(val_loss)

    model.train()

    print(f'epoch \t{epoch}\t -- train loss: {epoch_loss} -- val loss: {val_loss}')

  return model

In [None]:
def _build_image_pipeline():
  """Resize to 224x224, convert to a tensor and normalise with ImageNet statistics."""
  # Optional steps that are currently disabled: transforms.Grayscale() before
  # the resize and, for training only, transforms.ColorJitter(0.4, 0.4, 0.0, 0.3)
  # after ToTensor.
  steps = [
      transforms.Resize((224, 224)),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]),
  ]
  return transforms.Compose(steps)


# Training and evaluation currently use the same preprocessing, but are kept as
# separate objects so that augmentation can be added to the training one only.
image_transformator = _build_image_pipeline()

test_image_transformator = _build_image_pipeline()

def action_transformator(action):
  """Return ``action`` unchanged (identity transform)."""
  # Disabled alternative: min-max scale every value from [-6.284, 6.284] to [0, 1]:
  #   [(a - (-6.284)) / (6.284 - (-6.284)) for a in action]
  return action


In [None]:
# Number of consecutive target vectors concatenated per sample (see with_action_horizon).
ACTION_HORIZON = 1
# Look-ahead, in merged samples, between an observation and the target it is
# paired with (passed as `step` to reduce_frequency_by_step).
FREQUENCY_REDUCTION = 25

In [None]:
# Episode split: nine episodes for training, two held out for validation.
train_data = episodes_data(
    numbers_of_episodes=[21, 23, 24, 25, 26, 27, 30, 31, 32],
    reduction_step=FREQUENCY_REDUCTION,
    action_horizon=ACTION_HORIZON,
)
val_data = episodes_data(
    numbers_of_episodes=[28, 29],
    reduction_step=FREQUENCY_REDUCTION,
    action_horizon=ACTION_HORIZON,
)

In [None]:
# Wrap the sample dicts into torch datasets; validation uses the test-time image pipeline.
train_dataset = CustomImageDataset(
    data=train_data,
    image_transform=image_transformator,
    action_transform=action_transformator,
)
val_dataset = CustomImageDataset(
    data=val_data,
    image_transform=test_image_transformator,
    action_transform=action_transformator,
)

In [None]:
# Policy with a pretrained ResNet-34 backbone.
model = ResnetVisiomotorPolicy(action_horizon=ACTION_HORIZON, pretrained_resnet=True)

In [None]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
device

In [None]:
# Per-epoch loss histories; every training call below appends to them.
train_loss_list, val_loss_list = [], []

In [None]:
# Stage 1: main training run, 50 epochs at lr = 1e-5.
model = train_visuomotor_policy(
    model,
    train_dataset,
    learning_rate=1e-5,
    batch_size=8,
    num_of_epochs=50,
    device=device,
    shuffle=True,
    validation_set=val_dataset,
    loss_list=train_loss_list,
    val_loss_list=val_loss_list,
)

In [None]:
# Stage 2: continue for 10 epochs at lr = 1e-7 (a new optimizer is created per call).
model = train_visuomotor_policy(
    model,
    train_dataset,
    learning_rate=1e-7,
    batch_size=8,
    num_of_epochs=10,
    device=device,
    shuffle=True,
    validation_set=val_dataset,
    loss_list=train_loss_list,
    val_loss_list=val_loss_list,
)

In [None]:
# Stage 3: final 10 epochs at lr = 1e-9.
model = train_visuomotor_policy(
    model,
    train_dataset,
    learning_rate=1e-9,
    batch_size=8,
    num_of_epochs=10,
    device=device,
    shuffle=True,
    validation_set=val_dataset,
    loss_list=train_loss_list,
    val_loss_list=val_loss_list,
)

In [None]:
# Checkpoint file name; the leading '/' is needed because it is appended to
# DATA_PATH + '/models' by string concatenation.
TRAINED_MODEL = '/comparation_resnet_model_v1_1'

In [None]:
# Persists the whole module object (not only its state_dict), so the class
# definitions from this notebook must be available when it is loaded again.
torch.save(model, f"{DATA_PATH}/models{TRAINED_MODEL}")

**Model testing**

In [None]:
def test_resnet(model, dataset, device):
  criterion = nn.L1Loss(reduction="mean")
  dataloader = torch.utils.data.DataLoader(
      dataset=dataset, batch_size=1, shuffle=False)

  diffs = []

  model = model.to(device)
  model.eval()
  loss_sum = 0
  batches_num = 0
  for data in dataloader:
    batches_num += 1
    images, labels, currents = data
    images = images.to(device)
    labels = labels.to(device)
    currents = currents.to(device)

    prediction = model(images, currents)

    loss = criterion(prediction, labels)

    diff = torch.abs(prediction - labels)
    diffs.append(diff.max().item())

    loss_sum += loss.item()
  epoch_loss = loss_sum / batches_num
  return epoch_loss, diffs

In [None]:
# The checkpoint is a fully pickled module (it was written with torch.save(model, ...)),
# so it has to be loaded with weights_only=False; recent PyTorch releases default to
# weights_only=True and refuse such files. Unpickling executes code from the file:
# only load checkpoints you created yourself. The model classes defined in this
# notebook must already be defined when this cell runs.
model = torch.load(DATA_PATH + '/models' + '/comparation_resnet_model_v1_1',
                   map_location=device, weights_only=False)

In [None]:
# Switch to evaluation mode; as the last expression of the cell this also displays the module.
model.eval()

In [None]:
# Held-out test episodes 19 and 22.
test_data = episodes_data(
    numbers_of_episodes=[19, 22],
    reduction_step=FREQUENCY_REDUCTION,
    action_horizon=ACTION_HORIZON,
)
test_d = CustomImageDataset(
    data=test_data,
    image_transform=test_image_transformator,
    action_transform=action_transformator,
)

In [None]:
# Mean L1 error on the test set, then the five smallest and five largest
# per-sample worst-case errors.
model.eval()
metric, diffs = test_resnet(model=model, dataset=test_d, device=device)
print('res= ', metric)
print(sorted(diffs)[:5])
print(sorted(diffs)[-5:])

In [None]:
# Same evaluation restricted to episode 22 only.
test_data = episodes_data(
    numbers_of_episodes=[22],
    reduction_step=FREQUENCY_REDUCTION,
    action_horizon=ACTION_HORIZON,
)
test_d = CustomImageDataset(
    data=test_data,
    image_transform=test_image_transformator,
    action_transform=action_transformator,
)

In [None]:
# Mean L1 error on episode 22, then the five smallest and five largest
# per-sample worst-case errors.
model.eval()
metric, diffs = test_resnet(model=model, dataset=test_d, device=device)
print('res= ', metric)
print(sorted(diffs)[:5])
print(sorted(diffs)[-5:])

**Plots**

In [None]:
fig, ax = plt.subplots()

# The first recorded epoch is left out of the plot, so the curves are drawn
# against explicit epoch indices starting at 1; otherwise the x axis would be
# shifted by one epoch.
ax.plot(range(1, len(train_loss_list)), train_loss_list[1:],
        label='train', color='maroon')
ax.plot(range(1, len(val_loss_list)), val_loss_list[1:],
        label='validation', color='green')

ax.set_title('L1 loss per epoch')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')

ax.legend()

plt.show()