In [1]:
# -*- coding: utf-8 -*-
import csv
import tqdm
import copy
import click
import logging
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset

from string import digits

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torch.nn.functional as F

seed = 42
epochs = 200
batch_size = 128
learning_rate = 1e-3
context_frames = 10
sequence_length = 16
lookback = sequence_length

context_epochs = 20
context_batch_size = 1
context_learning_rate = 1e-3
context_data_length = 20

valid_train_split = 0.8  # precentage of train data from total
test_train_split = 0.9  # precentage of train data from total

torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")#  use gpu if available
################################# CHANGE THIS!!!!  #################################
model_path = "/home/user/Robotics/slip_detection_model/slip_detection_model/manual_data_models/models/conv_model_001/"
################################# CHANGE THIS!!!!  #################################


In [2]:
class BatchGenerator:
    def __init__(self, data_dir):
        self.data_dir = data_dir
        data_map = []
        with open(data_dir + 'map.csv', 'r') as f:  # rb
            reader = csv.reader(f)
            for row in reader:
                data_map.append(row)

        if len(data_map) <= 1: # empty or only header
            print("No file map found")
            exit()

        self.data_map = data_map

    def load_full_data(self):
        dataset_train = FullDataSet(self.data_dir, self.data_map, type_="train")
        dataset_valid = FullDataSet(self.data_dir, self.data_map, type_="valid")
        dataset_test = FullDataSet(self.data_dir, self.data_map, type_="test")
        transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
        train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
        valid_loader = torch.utils.data.DataLoader(dataset_valid, batch_size=batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False)
        return train_loader, valid_loader, test_loader

class FullDataSet():
    def __init__(self, data_dir, data_map, type_="train"):
        if type_ == "train":
            self.samples = data_map[1:int(len(data_map)*test_train_split)]
        elif type_ == "valid":
            self.samples = data_map[int(len(data_map)*(valid_train_split)):int(len(data_map)*test_train_split)]
        elif type_ == "test":
            self.samples = data_map[int(len(data_map)*test_train_split):-1]
        data_map = None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self,idx):
        value = self.samples[idx]
        robot = np.load(data_dir + value[0])
        xela1image = np.load(data_dir + value[3])
        experiment = np.load(data_dir + value[-2])
        time_step  = np.load(data_dir + value[-1])     
        return([robot.astype(np.float32),
                             xela1image.astype(np.float32),
                             experiment,
                             time_step])

In [3]:
class ConvLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim, kernel_size, bias):
        super(ConvLSTMCell, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.padding = kernel_size[0] // 2, kernel_size[1] // 2
        self.bias = bias
        self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, out_channels=4 * self.hidden_dim,
                              kernel_size=self.kernel_size, padding=self.padding, bias=self.bias)

    def forward(self, input_tensor, cur_state):
        h_cur, c_cur = cur_state
        combined = torch.cat([input_tensor, h_cur], dim=1)  # concatenate along channel axis
        combined_conv = self.conv(combined)
        cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
        i = torch.sigmoid(cc_i)
        f = torch.sigmoid(cc_f)
        o = torch.sigmoid(cc_o)
        g = torch.tanh(cc_g)
        c_next = f * c_cur + i * g
        h_next = o * torch.tanh(c_next)
        return h_next, c_next

    def init_hidden(self, batch_size, image_size):
        height, width = image_size
        return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device).to(device),
                torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device).to(device))

class FullModel(nn.Module):
    def __init__(self):
        super(FullModel, self).__init__()
        self.convlstm1 = ConvLSTMCell(input_dim=3, hidden_dim=12, kernel_size=(5, 5), bias=True).cuda()
        self.convlstm2 = ConvLSTMCell(input_dim=24, hidden_dim=3, kernel_size=(5, 5), bias=True).cuda()
#         self.conv1     = nn.Conv3d(in_channels=24, out_channels=3, kernel_size=5, 
#                                    stride=1, padding=2, bias=True)

    def tile(self, state_action):
        tiled_state_action = torch.zeros(self.batch_size, 12, 32, 32)
        for batch in range(self.batch_size):
            for index, feature in enumerate(state_action[batch]):
                tiled_state_action[batch][index][:] = feature
        return tiled_state_action

    def forward(self, tactiles, actions):
        self.batch_size = actions.shape[1]
        state = actions[0]
        state.to(device)
        batch_size__ = tactiles.shape[1]
        hidden_1, cell_1 = self.convlstm1.init_hidden(batch_size=self.batch_size, image_size=(32, 32))
        hidden_2, cell_2 = self.convlstm2.init_hidden(batch_size=self.batch_size, image_size=(32, 32))
        outputs = []
        for index, (sample_tactile, sample_action) in enumerate(zip(tactiles.squeeze(), actions.squeeze())):
            sample_tactile.to(device)
            sample_action.to(device)
            # 2. Run through lstm:
            if index > context_frames-1:
                hidden_1, cell_1 = self.convlstm1(input_tensor=cell_2, cur_state=[hidden_1, cell_1])
                state_action = torch.cat((state, sample_action), 1)
                state_action_tile = self.tile(state_action).to(device)
                robot_and_tactile = torch.cat((state_action_tile.squeeze(), cell_1.squeeze()), 1)
                hidden_2, cell_2 = self.convlstm2(input_tensor=robot_and_tactile, cur_state=[hidden_2, cell_2])
                outputs.append(cell_2)
            else:
                hidden_1, cell_1 = self.convlstm1(input_tensor=sample_tactile, cur_state=[hidden_1, cell_1])
                state_action = torch.cat((state, sample_action), 1)
                state_action_tile = self.tile(state_action).to(device)
                robot_and_tactile = torch.cat((state_action_tile.squeeze(), cell_1.squeeze()), 1)
                hidden_2, cell_2 = self.convlstm2(input_tensor=robot_and_tactile, cur_state=[hidden_2, cell_2])
        return torch.stack(outputs)

In [7]:
class ModelTrainer:
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.train_full_loader, self.valid_full_loader, self.test_full_loader = BG.load_full_data()
        self.full_model = FullModel()
        self.criterion = nn.L1Loss()
        self.optimizer = optim.Adam(self.full_model.parameters(), lr=learning_rate)

    def train_full_model(self):
        best_model_loss_score = 10.0 
        self.plot_training_loss = []
        self.plot_validation_loss = []
        previous_val_mean_loss = 1.0
        early_stop_clock = 0
        progress_bar = tqdm.tqdm(range(0, epochs), total=(epochs*len(self.train_full_loader)))
        mean_test = 0
        for epoch in progress_bar:
            loss = 0
            losses = 0.0
            for index, batch_features in enumerate(self.train_full_loader):
                action  = batch_features[0].permute(1,0,2).to(device)
                tactile = batch_features[1].permute(1,0,4,2,3).to(device)
                tactile_predictions = self.full_model.forward(tactiles=tactile, actions=action) # Step 3. Run our forward pass.
                self.optimizer.zero_grad()
                loss = self.criterion(tactile_predictions.to(device), tactile[context_frames:])
                loss.backward()
                self.optimizer.step()

                losses += loss.item()
                if index:
                    mean = losses / index
                else:
                    mean = 0
                progress_bar.set_description("epoch: {}, ".format(epoch) + "loss: {:.4f}, ".format(float(loss.item())) + "mean loss: {:.4f}, ".format(mean))
                progress_bar.update()
                self.plot_training_loss.append(mean)

            val_losses = 0.0
            val_loss = 0.0
            with torch.no_grad():
                for index__, batch_features in enumerate(self.valid_full_loader):
                    action = batch_features[0].permute(1,0,2).to(device)
                    tactile = batch_features[1].permute(1,0,4,2,3).to(device)
                    tactile_predictions = self.full_model.forward(tactiles=tactile, actions=action)  # Step 3. Run our forward pass.
                    self.optimizer.zero_grad()
                    val_loss = self.criterion(tactile_predictions.to(device), tactile[context_frames:])
                    val_losses += val_loss.item()

            print("Validation mean loss: {:.4f}, ".format(val_losses / index__))
            self.plot_validation_loss.append(val_losses / index__)
            if previous_val_mean_loss < val_losses / index__:
                early_stop_clock +=1
                previous_val_mean_loss = val_losses / index__ 
                if early_stop_clock == 3:
                    print("Early stopping")
                    break
            else:
                if (val_losses / index__) < best_model_loss_score:
                    self.strongest_model = copy.deepcopy(self.full_model)
                early_stop_clock = 0
                previous_val_mean_loss = val_losses / index__
        plt.plot(self.plot_training_loss, c="r", label="train loss MAE")
        plt.plot(self.plot_validation_loss, c='b', label="val loss MAE")
        plt.legend(loc="upper right")
        plt.show()
        plt.savefig(model_path + '/model_training_plot.png', dpi=300)
        np.save(model_path + 'training_loss', np.asarray(self.plot_training_loss))
        np.save(model_path + 'validation_loss', np.asarray(self.plot_validation_loss))

In [8]:
data_dir = '/home/user/Robotics/Data_sets/slip_detection/manual_slip_detection/'
BG = BatchGenerator(data_dir)
print("done")

done


In [None]:
MT = ModelTrainer(data_dir)
MT.train_full_model()
print("finished training")
torch.save(MT.strongest_model, model_path + "full_model")
model = torch.load(model_path + "full_model")
model.eval()
print("saved the model")

epoch: 0, loss: 0.0255, mean loss: 0.0455, :   1%|          | 712/142200 [23:32<1684:55:09, 42.87s/it]

Validation mean loss: 0.0245, 


epoch: 1, loss: 0.0240, mean loss: 0.0253, :   1%|          | 1423/142200 [46:27<89:38:04,  2.29s/it] 

Validation mean loss: 0.0214, 


epoch: 2, loss: 0.0275, mean loss: 0.0240, :   1%|          | 1545/142200 [52:41<64:57:19,  1.66s/it]  