<a href="https://colab.research.google.com/github/ItoMasaki/Notebook-List/blob/main/GraspWithYOLOv1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [190]:
%load_ext tensorboard

In [180]:
!pip3 install pytorch_lightning



In [201]:
import torch
from torch.nn import functional as F
from torch.utils.data import random_split
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

from glob import glob

import re

from skimage.io import imread

from PIL import Image, ImageDraw

from numpy import array

from random import gauss

In [182]:
def atoi(text):
    """Convert ``text`` to an ``int`` when it consists only of decimal digits.

    Args:
        text: A string chunk, typically one piece of a path split on digit runs.

    Returns:
        ``int(text)`` if every character is a decimal digit, otherwise ``text``
        unchanged (this includes the empty string).
    """
    # str.isdigit() is also true for characters such as superscript digits
    # ("\u00b2"), which int() cannot parse and would raise ValueError on.
    # str.isdecimal() only accepts characters that int() understands.
    return int(text) if text.isdecimal() else text
    
def natural_keys(text):
    """Build a sort key that compares digit runs in ``text`` as numbers.

    The string is split on runs of digits (the runs themselves are kept), and
    every piece is passed through ``atoi``. The resulting list can be used as
    ``key=`` for ``sorted`` so that e.g. ``"img2"`` orders before ``"img10"``.

    Args:
        text: The string (here: a file path) to build a key for.

    Returns:
        A list of the pieces of ``text`` after conversion by ``atoi``.
    """
    key = []
    for chunk in re.split(r'(\d+)', text):
        key.append(atoi(chunk))
    return key

In [183]:
class LoadData():
  """In-memory grasp-point dataset built from paired annotation and image files.

  Annotation and image paths are globbed, naturally sorted and paired by
  position. Every annotation file that has exactly three lines contributes
  one sample carrying the exact label, followed by ``n_augment`` extra samples
  that reuse the same image and ``dist`` but whose label coordinates are
  perturbed with Gaussian noise.

  Items are returned as ``([image, dist], label)``.
  """

  def __init__(self,
               annotations_pattern="/content/drive/MyDrive/datasets/image_based_grasp_point/Annotations/*",
               images_pattern="/content/drive/MyDrive/datasets/image_based_grasp_point/Images/*",
               image_size=(128, 128),
               n_augment=20,
               noise_std=0.01):
    """Load every sample into memory.

    Args:
      annotations_pattern: Glob pattern matching the annotation text files.
      images_pattern: Glob pattern matching the image files.
      image_size: Size passed to ``PIL.Image.resize`` for every image.
      n_augment: Number of noisy label copies generated per annotated image.
      noise_std: Standard deviation handed to ``random.gauss`` when jittering
        the normalised label coordinates.

    The defaults reproduce the previously hard-coded values.
    """
    T_annotations_pathes = sorted(glob(annotations_pattern), key=natural_keys)
    T_images_pathes = sorted(glob(images_pattern), key=natural_keys)

    # Kept for backward compatibility; nothing in this class fills or reads them.
    self.X = []
    self.Y = []

    self.T_label_data = []
    self.T_image_data = []
    self.T_distr_data = []

    # NOTE(review): annotations and images are paired purely by their index in
    # the sorted listings -- this assumes both folders hold matching files in
    # the same order; confirm against the dataset layout.
    for idx, annotation_path in enumerate(T_annotations_pathes):
      with open(annotation_path, "r") as f:
        data = f.readlines()

      # Only files with exactly three lines (two points + dist) are used.
      if len(data) != 3:
        continue

      # Decode the image once and reuse it for both the original shape and the
      # resized network input (it used to be read from disk twice).
      pil_image = Image.fromarray(imread(T_images_pathes[idx]))
      shape = array(pil_image).shape
      # Scale to [0, 1] and move the last axis to the front.
      # assumes imread yields a 3-D (rows, cols, channels) array -- TODO confirm
      image = (array(pil_image.resize(image_size))/255).transpose(2, 0, 1)

      # SECURITY(review): eval() executes whatever expression the annotation
      # file contains. If the third line is always a plain Python literal,
      # ast.literal_eval would be the safe replacement -- confirm the file
      # format before switching.
      dist = eval(data[2])

      point_1 = self._parse_point(data[0], shape)
      point_2 = self._parse_point(data[1], shape)
      exact_label = point_1 + point_2

      self.T_label_data.append(array(exact_label))
      self.T_image_data.append(image)
      self.T_distr_data.append(array(dist))

      # Label-noise augmentation: same image and dist, jittered coordinates.
      for _ in range(n_augment):
        noisy_label = [gauss(value, noise_std) for value in exact_label]
        self.T_label_data.append(array(noisy_label))
        self.T_image_data.append(image)
        self.T_distr_data.append(array(dist))

  @staticmethod
  def _parse_point(line, shape):
    """Parse space-separated integers from ``line``, dividing the i-th by ``shape[i]``."""
    # NOTE(review): the first value is normalised by shape[0] and the second by
    # shape[1]; whether annotations are stored in that axis order is not
    # visible here -- confirm.
    return [int(p)/shape[i] for i, p in enumerate(line.replace("\n", "").split(" "))]

  def __getitem__(self, index):
    """Return ``([image, dist], label)`` for the sample at ``index``."""
    return [self.T_image_data[index], self.T_distr_data[index]], self.T_label_data[index]

  def __len__(self):
    """Return the number of stored samples (exact plus augmented)."""
    return len(self.T_label_data)

# Eagerly build the whole dataset when this cell runs.
# NOTE(review): `load_data` is not referenced again in the visible notebook --
# GraspingPoinntDataModule.setup() constructs its own LoadData() -- so this
# call may be redundant; confirm before removing.
load_data = LoadData()

In [184]:
class GraspingPoinntDataModule(pl.LightningDataModule):
    """Lightning data module that serves the ``LoadData`` dataset for training.

    Only a training dataloader is defined; the validation and test loaders
    are still commented out below.
    """

    def __init__(self, path, batch_size, num_workers=16):
        """Store the loader configuration.

        Args:
            path: Kept on the instance as ``self.path``; no method visible in
                this class reads it (``LoadData`` is built without arguments).
            batch_size: Batch size used by the training dataloader.
            num_workers: Number of worker processes for the training
                dataloader. Defaults to 16, the previously hard-coded value.
        """
        super().__init__()

        self.path = path
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):
        """Build the training dataset; ``stage`` is accepted but not used."""
        self.grasping_point_train = LoadData()
        # TODO: validation/test datasets are not wired up yet.
        #self.grasping_point_test = ExpertDataset(self.path, validation_flag=True)
        #self.grasping_point_val = ExpertDataset(self.path, validation_flag=True)
        # self.grasping_point_train, self.grasping_point_val = random_split(grasping_point_full, [len(grasping_point_full)-val_num, val_num])

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over the training dataset."""
        return DataLoader(
            self.grasping_point_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    #def val_dataloader(self):
    #    return DataLoader(self.grasping_point_val, batch_size=len(self.grasping_point_val), num_workers=16, shuffle=False)

    #def test_dataloader(self):
    #    return DataLoader(self.grasping_point_test, batch_size=len(self.grasping_point_val), shuffle=False)

In [185]:
import torch    
from torch import nn    
    
class ExtractFeature(nn.Module):
    """Convolutional feature extractor for 3-channel images.

    A stem convolution (kernel 4, stride 2, padding 1) maps the 3 input
    channels to ``in_channels``. Six further stride-2 convolutions follow,
    widening the channels from ``in_channels`` to ``7 * in_channels`` in
    steps of ``in_channels``, with a ReLU between consecutive convolutions
    (none after the last one).
    """

    def __init__(self, in_channels=4, out_channels=32, kernel_size=3):
        """Create the layers.

        Args:
            in_channels: Channel count produced by the stem and used as the
                width step of the backbone.
            out_channels: Accepted for interface compatibility; not used. The
                output always has ``7 * in_channels`` channels.
            kernel_size: Kernel size of the six backbone convolutions. Padding
                stays fixed at 1.
        """
        super().__init__()

        # The stem used to hard-code 4 output channels, which only matched the
        # backbone for the default in_channels=4.
        self.up_channel = nn.Conv2d(3, in_channels, 4, 2, 1)

        # Layer order and Sequential indices (conv at 0, 2, ..., 10) are kept
        # exactly as before so existing checkpoints still load.
        layers = []
        for step in range(1, 7):
            if layers:
                layers.append(nn.ReLU())
            layers.append(
                nn.Conv2d(in_channels*step, in_channels*(step + 1), kernel_size, 2, 1)
            )
        self.backborn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stem and then the backbone to ``x``."""
        return self.backborn(self.up_channel(x))

In [186]:
import torch    
from torch import nn    
from torch.utils.tensorboard import SummaryWriter    
    
    
class Encoder(nn.Module):
    """Fully connected encoder: five ``Linear`` layers, each followed by ``SiLU``.

    Layer widths are ``in_channels -> 64 -> 56 -> 48 -> 40 -> out_channels``.
    """

    def __init__(self, in_channels=28+16, out_channels=32):
        """Create the layer stack.

        Args:
            in_channels: Size of the input feature vector.
            out_channels: Size of the encoded vector. It used to be ignored
                (the last layer was fixed at 32); the default keeps that
                width.
        """
        super().__init__()

        widths = [in_channels, 64, 56, 48, 40, out_channels]

        # Linear layers sit at even Sequential indices (0, 2, ..., 8), exactly
        # as in the hand-written stack, so checkpoint keys are unchanged.
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.SiLU())
        self.encode = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` with the layer stack."""
        return self.encode(x)

In [187]:
class Decoder(nn.Module):
    """Decode a 32-wide code into two 2-value positions.

    A shared trunk (32 -> 16 -> 8, SiLU activations) feeds two independent
    heads of identical shape (8 -> 6 -> 4 -> 2, ReLU activations, final
    Sigmoid). The two head outputs are concatenated along dimension 1, giving
    four values per sample, each squashed by the Sigmoid.
    """

    def __init__(self, in_channels=28+16, out_channels=5):
        """Create the trunk and both heads.

        ``in_channels`` and ``out_channels`` are accepted but not used: all
        layer sizes are fixed.
        """
        super().__init__()

        self.decode = nn.Sequential(
            nn.Linear(32, 16), nn.SiLU(),
            nn.Linear(16, 8), nn.SiLU(),
        )

        # Two separately parameterised heads with the same layout.
        self.decode_pos_1 = self._position_head()
        self.decode_pos_2 = self._position_head()

    @staticmethod
    def _position_head():
        """Build one 8 -> 6 -> 4 -> 2 head ending in a Sigmoid."""
        return nn.Sequential(
            nn.Linear(8, 6), nn.ReLU(),
            nn.Linear(6, 4), nn.ReLU(),
            nn.Linear(4, 2), nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the concatenated outputs of both position heads."""
        shared = self.decode(x)
        first = self.decode_pos_1(shared)
        second = self.decode_pos_2(shared)
        return torch.cat((first, second), 1)

In [198]:
class Network(pl.LightningModule):
  """Lightning module that predicts four values from an image and a ``dist`` vector.

  Image features (``ExtractFeature`` + flatten) are concatenated with a
  linear projection of ``dist`` (11 -> 16), passed through ``Encoder`` and
  then ``Decoder``.
  """

  def __init__(self):
    super().__init__()

    # NOTE(review): this writer is not used by any method of this class;
    # training metrics go through self.log() instead.
    self.writer = SummaryWriter(log_dir="/content/drive/MyDrive/")

    # Attribute names and creation order are kept as-is so that state-dict
    # keys stay stable.
    self.extract_feature = ExtractFeature()
    self.linear_transformation = nn.Linear(11, 16)
    self.flatten = nn.Flatten()
    self.encode = Encoder()
    self.decoder = Decoder()

  def forward(self, image, dist):
    """Run the full model; both inputs are cast to float32 first."""
    image_features = self.extract_feature(image.float())
    image_features = self.flatten(image_features)
    dist_features = self.linear_transformation(dist.float())

    fused = torch.cat((image_features, dist_features), 1).float()
    return self.decoder(self.encode(fused))

  def training_step(self, batch, batch_idx):
    """Compute and log the MSE between prediction and label (first four columns)."""
    (image, dist), target = batch

    prediction = self(image, dist)
    loss = F.mse_loss(prediction[:, 0:4], target.float()[:, 0:4])

    self.log('loss/train', loss)
    return loss

  def configure_optimizers(self):
    """Return an AdamW optimizer over all parameters with lr=1e-4."""
    return torch.optim.AdamW(self.parameters(), lr=1e-4)

In [202]:
# from models.GraspingNetwork import Network

# Training configuration.
batch_size = 128
max_epochs = 300
gpus = 0
# Set to None to train from scratch. NOTE: this path lives under
# /content/lightning_logs, so it only exists if an earlier run in the same
# runtime produced it.
checkpoint_path = "/content/lightning_logs/version_10/checkpoints/epoch=210-step=17090.ckpt"

def main():
    """Build the logger, data module and model, then run training.

    Training resumes from ``checkpoint_path`` when it is set. If it names a
    local file that does not exist, a warning is emitted and training starts
    from scratch instead of failing inside ``trainer.fit``.
    """
    # Local imports keep this cell self-contained.
    import os
    import warnings

    tb_logger = pl_loggers.TensorBoardLogger("/content/drive/MyDrive/logs")
    GPDataset = GraspingPoinntDataModule("data/my_research", batch_size)
    model = Network()

    # NOTE(review): the `gpus` argument was deprecated and later removed in
    # newer pytorch_lightning releases in favour of `accelerator`/`devices`.
    # The install cell is unpinned, so confirm against the installed version.
    trainer = pl.Trainer(gpus=gpus, max_epochs=max_epochs, logger=tb_logger)

    resume_from = checkpoint_path
    # Only local paths are checked; anything that looks like a URL is passed
    # through untouched.
    if (resume_from is not None
            and "://" not in resume_from
            and not os.path.isfile(resume_from)):
        warnings.warn(
            "Checkpoint %r not found; training from scratch." % resume_from
        )
        resume_from = None

    # ckpt_path=None is the same as not passing it.
    trainer.fit(model, GPDataset, ckpt_path=resume_from)

main()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Restoring states from the checkpoint path at /content/lightning_logs/version_10/checkpoints/epoch=210-step=17090.ckpt
Missing logger folder: /content/drive/MyDrive/logs/default
  "You're resuming from a checkpoint that ended mid-epoch."
Restored all states from the checkpoint file at /content/lightning_logs/version_10/checkpoints/epoch=210-step=17090.ckpt

  | Name                  | Type           | Params
---------------------------------------------------------
0 | extract_feature       | ExtractFeature | 16.4 K
1 | linear_transformation | Linear         | 192   
2 | flatten               | Flatten        | 0     
3 | encode                | Encoder        | 12.5 K
4 | decoder               | Decoder        | 848   
---------------------------------------------------------
30.0 K    Trainable params
0         Non-trainable params
30.0 K    Total params
0.120     Total estim

Training: 0it [00:00, ?it/s]