In [1]:
import os
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class GripperMetalSheetDataset(Dataset):
    def __init__(self, gripper_dir, metal_sheet_dir, transform=None):
        self.gripper_dir = gripper_dir
        self.metal_sheet_dir = metal_sheet_dir
        self.transform = transform
        self.gripper_files = [f for f in os.listdir(gripper_dir) if f.endswith('.png')]
        self.metal_sheet_files = [f for f in os.listdir(metal_sheet_dir) if f.endswith('.png')]

    def __len__(self):
        return len(self.gripper_files)

    def __getitem__(self, idx):
        gripper_path = os.path.join(self.gripper_dir, self.gripper_files[idx])
        metal_sheet_path = os.path.join(self.metal_sheet_dir, self.metal_sheet_files[idx])
        
        gripper_image = Image.open(gripper_path).convert("RGB")
        metal_sheet_image = Image.open(metal_sheet_path).convert("RGB")
        
        # Generate random transformations for training
        shift_x = random.randint(0, 50)
        shift_y = random.randint(0, 50)
        rotation = random.randint(0, 360)
        
        if self.transform:
            gripper_image = self.transform(gripper_image)
            metal_sheet_image = self.transform(metal_sheet_image)
        
        label = torch.tensor([shift_x, shift_y, rotation], dtype=torch.float32)
        
        return gripper_image, metal_sheet_image, label

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

dataset = GripperMetalSheetDataset('data/train_images/Grippers', 'data/train_images/Metal_sheets', transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn as nn
import torch.nn.functional as F

class GripperMetalSheetModel(nn.Module):
    def __init__(self):
        super(GripperMetalSheetModel, self).__init__()
        self.gripper_conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.metal_sheet_conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc = nn.Sequential(
            nn.Linear(32 * 64 * 64 * 2, 512),
            nn.ReLU(),
            nn.Linear(512, 3)  # Output: shift_x, shift_y, rotation
        )

    def forward(self, gripper_image, metal_sheet_image):
        gripper_features = self.gripper_conv(gripper_image)
        metal_sheet_features = self.metal_sheet_conv(metal_sheet_image)
        
        gripper_features = gripper_features.view(gripper_features.size(0), -1)
        metal_sheet_features = metal_sheet_features.view(metal_sheet_features.size(0), -1)
        
        combined_features = torch.cat((gripper_features, metal_sheet_features), dim=1)
        output = self.fc(combined_features)
        
        return output

model = GripperMetalSheetModel()

In [8]:
import torch.optim as optim

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for gripper_image, metal_sheet_image, labels in dataloader:
        optimizer.zero_grad()
        outputs = model(gripper_image, metal_sheet_image)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}")

print("Training complete.")

Epoch [1/10], Loss: 17687.3711
Epoch [2/10], Loss: 6133.6260
Epoch [3/10], Loss: 3372.4402
Epoch [4/10], Loss: 4456.1021
Epoch [5/10], Loss: 3710.9297
Epoch [6/10], Loss: 1862.4824
Epoch [7/10], Loss: 5660.7207
Epoch [8/10], Loss: 4746.9907
Epoch [9/10], Loss: 2905.3569
Epoch [10/10], Loss: 6602.7212
Training complete.
