In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import json
import os, json, cv2, numpy as np, matplotlib.pyplot as plt, yaml
import pandas as pd
import torchvision.models as models
import torch.nn as nn

In [2]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

def load_class_info(yaml_file):
    with open(yaml_file, 'r') as file:
            class_info = yaml.safe_load(file)
    return class_info['classes']

In [3]:
class ClassDataset(Dataset):
    def __init__(self, dataset_folder, class_info_file, transform=None, demo=False):
        self.dataset_folder = dataset_folder
        self.transform = transform
        # self.label_transform = label_transform
        self.demo = demo
        self.imgs_files = self.load_data(dataset_folder)
        self.class_names = load_class_info(class_info_file)
        self.num_classes = len(self.class_names)
        # self.normalize_keypoints = normalize_keypoints
        self.class_to_idx = {class_name: idx for idx, class_name in enumerate(self.class_names)}
        print(self.class_to_idx)

    def load_data(self, dataset_folder):
        images_path = os.path.join(self.dataset_folder,"images/")
        annotations_path = os.path.join(self.dataset_folder,"annotations/")
        j_data = []
        for file in os.listdir(images_path):
            if file.endswith(".jpg"):
                json_path = os.path.join(annotations_path, file.split('.')[0] + '.json')
                with open(json_path) as f:
                    json_load = json.load(f)
                    for item in json_load['shapes']:
                        points = [value for row in item['points'] for value in row]
                        j_data.append({'image':  os.path.join(images_path,file),
                                     'label': item['label'],
                                     'points':points })
        json_data = pd.DataFrame(j_data)
        return json_data

    def get_keypoint(self, bboxes):
        centers = []
        for bbox in bboxes:
            center_x = (bbox[0] + bbox[2]) / 2
            center_y = (bbox[1] + bbox[3]) / 2
            centers.append((center_x, center_y))
        return centers
    
    def __len__(self):
        return len(self.imgs_files)

    def __getitem__(self, idx):
        target = {}
        img_path, label, bboxes_original = self.imgs_files.iloc[idx]
        bboxes_original = [bboxes_original]
        label = torch.tensor(self.class_to_idx[label], dtype=torch.int32)
        keypoint_original = torch.tensor(self.get_keypoint(bboxes_original), dtype=torch.float32)
        # img_original = cv2.imread(img_path)
        # img_original = cv2.cvtColor(img_original, cv2.COLOR_BGR2RGB)
        # img_original = read_image(img_path).float() / 255.0
        img_original = Image.open(img_path).convert('RGB')

        if self.transform:
            img_original = self.transform(img_original)

        return img_original, keypoint_original, label

In [4]:
def collate_fn(batch):
    return tuple(zip(*batch))
    
class_config_path = '../../config/formated_class.yaml'
KEYPOINTS_FOLDER_TRAIN = '../../dataset/robocup_all_test/'
train_path = os.path.join(KEYPOINTS_FOLDER_TRAIN,"train/")
val_path = os.path.join(KEYPOINTS_FOLDER_TRAIN,"val/")

dataset = ClassDataset(train_path, class_config_path, transform=transform, demo=True)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

{'AllenKey': 0, 'Axis2': 1, 'Bearing2': 2, 'Drill': 3, 'F20_20_B': 4, 'F20_20_G': 5, 'Housing': 6, 'M20': 7, 'M20_100': 8, 'M30': 9, 'Motor2': 10, 'S40_40_B': 11, 'S40_40_G': 12, 'Screwdriver': 13, 'Spacer': 14, 'Wrench': 15, 'container_blue': 16, 'container_red': 17}


In [5]:
iterator = iter(dataloader)
images, keypoint, label = next(iterator)
print(label)
print(keypoint)
image_number = 0

(tensor(5, dtype=torch.int32), tensor(0, dtype=torch.int32), tensor(12, dtype=torch.int32), tensor(5, dtype=torch.int32), tensor(7, dtype=torch.int32), tensor(0, dtype=torch.int32), tensor(11, dtype=torch.int32), tensor(4, dtype=torch.int32), tensor(7, dtype=torch.int32), tensor(11, dtype=torch.int32), tensor(7, dtype=torch.int32), tensor(1, dtype=torch.int32), tensor(7, dtype=torch.int32), tensor(7, dtype=torch.int32), tensor(0, dtype=torch.int32), tensor(12, dtype=torch.int32), tensor(0, dtype=torch.int32), tensor(4, dtype=torch.int32), tensor(12, dtype=torch.int32), tensor(12, dtype=torch.int32), tensor(12, dtype=torch.int32), tensor(9, dtype=torch.int32), tensor(2, dtype=torch.int32), tensor(9, dtype=torch.int32), tensor(7, dtype=torch.int32), tensor(9, dtype=torch.int32), tensor(1, dtype=torch.int32), tensor(2, dtype=torch.int32), tensor(2, dtype=torch.int32), tensor(2, dtype=torch.int32), tensor(3, dtype=torch.int32), tensor(6, dtype=torch.int32))
(tensor([[244.5000, 327.5000]]),

In [6]:
class ResNetModified(nn.Module):
    # def __init__(self, base_model, num_classes):
    #     super().__init__()
    #     self.resnet = base_model(pretrained=True)
    #     self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes + 2)  # classes + x,y for center point

    # def forward(self, x):
    #     return self.resnet(x)
    def __init__(self, base_model, num_classes):
        super(ResNetModified, self).__init__()
        self.resnet = base_model(pretrained=True)
        self.features = nn.Sequential(*list(self.resnet.children())[:-1])
        self.classifier = nn.Linear(self.resnet.fc.in_features, num_classes)
        self.keypoints = nn.Linear(self.resnet.fc.in_features, 2)  # x, y coordinates

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        
        class_output = self.classifier(x)
        keypoints_output = self.keypoints(x)
        
        return class_output, keypoints_output


In [7]:
num_classes = len(load_class_info(class_config_path))
model = ResNetModified(models.resnet50, num_classes)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device) 

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion_cls = nn.CrossEntropyLoss()
criterion_pts = nn.MSELoss()



In [8]:
def train_model(model, dataloader, optimizer, num_epochs=50):
    model.train()
    loss_cls = []
    loss_kpt = []
    for epoch in range(num_epochs):
        epoch_loss = 0
        for images, keypoints, labels in dataloader:
            images, keypoints, labels = torch.stack(images).to(device), torch.stack(keypoints).to(device), torch.stack(labels).to(device)
            optimizer.zero_grad()
            class_outputs, keypoint_outputs = model(images)
            loss_cls = criterion_cls(class_outputs, labels)
            loss_pts = criterion_pts(keypoint_outputs, keypoints)
            loss = loss_pts
            loss.backward()
            optimizer.step()
            train_epoch_loss += loss_cls.item()
        train_epoch_loss /= len(dataloader.dataset)
        loss_cls.append(epoch_loss)
        loss_kpt.append(epoch_loss)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')
    return model, loss_history

In [10]:
num_epochs = 100
model, loss_history = train_model(model, dataloader, optimizer, num_epochs)
torch.save(model.state_dict(), './resnet50_normal.pth')

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1/100, Loss: 1522.3250
Epoch 2/100, Loss: 531.0817


KeyboardInterrupt: 

In [115]:
plt.plot(loss_history)
plt.title('Training Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.show()
plt.savefig('./kp_loss_train_val_'+.png', format='png')

SyntaxError: unterminated string literal (detected at line 6) (2431911355.py, line 6)

In [108]:
num_epochs = 100
model, loss_history = train_model(model, dataloader, optimizer, num_epochs)
torch.save(model.state_dict(), './resnet50_normal_2.pth')

RuntimeError: The size of tensor a (20) must match the size of tensor b (2) at non-singleton dimension 2

In [9]:
model.load_state_dict(torch.load('./models/resnet50_normal_1.pth'))
model.eval()  # Switch to evaluation mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
print("successful")

RuntimeError: Error(s) in loading state_dict for ResNetModified:
	Missing key(s) in state_dict: "features.0.weight", "features.1.weight", "features.1.bias", "features.1.running_mean", "features.1.running_var", "features.4.0.conv1.weight", "features.4.0.bn1.weight", "features.4.0.bn1.bias", "features.4.0.bn1.running_mean", "features.4.0.bn1.running_var", "features.4.0.conv2.weight", "features.4.0.bn2.weight", "features.4.0.bn2.bias", "features.4.0.bn2.running_mean", "features.4.0.bn2.running_var", "features.4.0.conv3.weight", "features.4.0.bn3.weight", "features.4.0.bn3.bias", "features.4.0.bn3.running_mean", "features.4.0.bn3.running_var", "features.4.0.downsample.0.weight", "features.4.0.downsample.1.weight", "features.4.0.downsample.1.bias", "features.4.0.downsample.1.running_mean", "features.4.0.downsample.1.running_var", "features.4.1.conv1.weight", "features.4.1.bn1.weight", "features.4.1.bn1.bias", "features.4.1.bn1.running_mean", "features.4.1.bn1.running_var", "features.4.1.conv2.weight", "features.4.1.bn2.weight", "features.4.1.bn2.bias", "features.4.1.bn2.running_mean", "features.4.1.bn2.running_var", "features.4.1.conv3.weight", "features.4.1.bn3.weight", "features.4.1.bn3.bias", "features.4.1.bn3.running_mean", "features.4.1.bn3.running_var", "features.4.2.conv1.weight", "features.4.2.bn1.weight", "features.4.2.bn1.bias", "features.4.2.bn1.running_mean", "features.4.2.bn1.running_var", "features.4.2.conv2.weight", "features.4.2.bn2.weight", "features.4.2.bn2.bias", "features.4.2.bn2.running_mean", "features.4.2.bn2.running_var", "features.4.2.conv3.weight", "features.4.2.bn3.weight", "features.4.2.bn3.bias", "features.4.2.bn3.running_mean", "features.4.2.bn3.running_var", "features.5.0.conv1.weight", "features.5.0.bn1.weight", "features.5.0.bn1.bias", "features.5.0.bn1.running_mean", "features.5.0.bn1.running_var", "features.5.0.conv2.weight", "features.5.0.bn2.weight", "features.5.0.bn2.bias", "features.5.0.bn2.running_mean", "features.5.0.bn2.running_var", "features.5.0.conv3.weight", "features.5.0.bn3.weight", "features.5.0.bn3.bias", "features.5.0.bn3.running_mean", "features.5.0.bn3.running_var", "features.5.0.downsample.0.weight", "features.5.0.downsample.1.weight", "features.5.0.downsample.1.bias", "features.5.0.downsample.1.running_mean", "features.5.0.downsample.1.running_var", "features.5.1.conv1.weight", "features.5.1.bn1.weight", "features.5.1.bn1.bias", "features.5.1.bn1.running_mean", "features.5.1.bn1.running_var", "features.5.1.conv2.weight", "features.5.1.bn2.weight", "features.5.1.bn2.bias", "features.5.1.bn2.running_mean", "features.5.1.bn2.running_var", "features.5.1.conv3.weight", "features.5.1.bn3.weight", "features.5.1.bn3.bias", "features.5.1.bn3.running_mean", "features.5.1.bn3.running_var", "features.5.2.conv1.weight", "features.5.2.bn1.weight", "features.5.2.bn1.bias", "features.5.2.bn1.running_mean", "features.5.2.bn1.running_var", "features.5.2.conv2.weight", "features.5.2.bn2.weight", "features.5.2.bn2.bias", "features.5.2.bn2.running_mean", "features.5.2.bn2.running_var", "features.5.2.conv3.weight", "features.5.2.bn3.weight", "features.5.2.bn3.bias", "features.5.2.bn3.running_mean", "features.5.2.bn3.running_var", "features.5.3.conv1.weight", "features.5.3.bn1.weight", "features.5.3.bn1.bias", "features.5.3.bn1.running_mean", "features.5.3.bn1.running_var", "features.5.3.conv2.weight", "features.5.3.bn2.weight", "features.5.3.bn2.bias", "features.5.3.bn2.running_mean", "features.5.3.bn2.running_var", "features.5.3.conv3.weight", "features.5.3.bn3.weight", "features.5.3.bn3.bias", "features.5.3.bn3.running_mean", "features.5.3.bn3.running_var", "features.6.0.conv1.weight", "features.6.0.bn1.weight", "features.6.0.bn1.bias", "features.6.0.bn1.running_mean", "features.6.0.bn1.running_var", "features.6.0.conv2.weight", "features.6.0.bn2.weight", "features.6.0.bn2.bias", "features.6.0.bn2.running_mean", "features.6.0.bn2.running_var", "features.6.0.conv3.weight", "features.6.0.bn3.weight", "features.6.0.bn3.bias", "features.6.0.bn3.running_mean", "features.6.0.bn3.running_var", "features.6.0.downsample.0.weight", "features.6.0.downsample.1.weight", "features.6.0.downsample.1.bias", "features.6.0.downsample.1.running_mean", "features.6.0.downsample.1.running_var", "features.6.1.conv1.weight", "features.6.1.bn1.weight", "features.6.1.bn1.bias", "features.6.1.bn1.running_mean", "features.6.1.bn1.running_var", "features.6.1.conv2.weight", "features.6.1.bn2.weight", "features.6.1.bn2.bias", "features.6.1.bn2.running_mean", "features.6.1.bn2.running_var", "features.6.1.conv3.weight", "features.6.1.bn3.weight", "features.6.1.bn3.bias", "features.6.1.bn3.running_mean", "features.6.1.bn3.running_var", "features.6.2.conv1.weight", "features.6.2.bn1.weight", "features.6.2.bn1.bias", "features.6.2.bn1.running_mean", "features.6.2.bn1.running_var", "features.6.2.conv2.weight", "features.6.2.bn2.weight", "features.6.2.bn2.bias", "features.6.2.bn2.running_mean", "features.6.2.bn2.running_var", "features.6.2.conv3.weight", "features.6.2.bn3.weight", "features.6.2.bn3.bias", "features.6.2.bn3.running_mean", "features.6.2.bn3.running_var", "features.6.3.conv1.weight", "features.6.3.bn1.weight", "features.6.3.bn1.bias", "features.6.3.bn1.running_mean", "features.6.3.bn1.running_var", "features.6.3.conv2.weight", "features.6.3.bn2.weight", "features.6.3.bn2.bias", "features.6.3.bn2.running_mean", "features.6.3.bn2.running_var", "features.6.3.conv3.weight", "features.6.3.bn3.weight", "features.6.3.bn3.bias", "features.6.3.bn3.running_mean", "features.6.3.bn3.running_var", "features.6.4.conv1.weight", "features.6.4.bn1.weight", "features.6.4.bn1.bias", "features.6.4.bn1.running_mean", "features.6.4.bn1.running_var", "features.6.4.conv2.weight", "features.6.4.bn2.weight", "features.6.4.bn2.bias", "features.6.4.bn2.running_mean", "features.6.4.bn2.running_var", "features.6.4.conv3.weight", "features.6.4.bn3.weight", "features.6.4.bn3.bias", "features.6.4.bn3.running_mean", "features.6.4.bn3.running_var", "features.6.5.conv1.weight", "features.6.5.bn1.weight", "features.6.5.bn1.bias", "features.6.5.bn1.running_mean", "features.6.5.bn1.running_var", "features.6.5.conv2.weight", "features.6.5.bn2.weight", "features.6.5.bn2.bias", "features.6.5.bn2.running_mean", "features.6.5.bn2.running_var", "features.6.5.conv3.weight", "features.6.5.bn3.weight", "features.6.5.bn3.bias", "features.6.5.bn3.running_mean", "features.6.5.bn3.running_var", "features.7.0.conv1.weight", "features.7.0.bn1.weight", "features.7.0.bn1.bias", "features.7.0.bn1.running_mean", "features.7.0.bn1.running_var", "features.7.0.conv2.weight", "features.7.0.bn2.weight", "features.7.0.bn2.bias", "features.7.0.bn2.running_mean", "features.7.0.bn2.running_var", "features.7.0.conv3.weight", "features.7.0.bn3.weight", "features.7.0.bn3.bias", "features.7.0.bn3.running_mean", "features.7.0.bn3.running_var", "features.7.0.downsample.0.weight", "features.7.0.downsample.1.weight", "features.7.0.downsample.1.bias", "features.7.0.downsample.1.running_mean", "features.7.0.downsample.1.running_var", "features.7.1.conv1.weight", "features.7.1.bn1.weight", "features.7.1.bn1.bias", "features.7.1.bn1.running_mean", "features.7.1.bn1.running_var", "features.7.1.conv2.weight", "features.7.1.bn2.weight", "features.7.1.bn2.bias", "features.7.1.bn2.running_mean", "features.7.1.bn2.running_var", "features.7.1.conv3.weight", "features.7.1.bn3.weight", "features.7.1.bn3.bias", "features.7.1.bn3.running_mean", "features.7.1.bn3.running_var", "features.7.2.conv1.weight", "features.7.2.bn1.weight", "features.7.2.bn1.bias", "features.7.2.bn1.running_mean", "features.7.2.bn1.running_var", "features.7.2.conv2.weight", "features.7.2.bn2.weight", "features.7.2.bn2.bias", "features.7.2.bn2.running_mean", "features.7.2.bn2.running_var", "features.7.2.conv3.weight", "features.7.2.bn3.weight", "features.7.2.bn3.bias", "features.7.2.bn3.running_mean", "features.7.2.bn3.running_var", "classifier.weight", "classifier.bias", "keypoints.weight", "keypoints.bias". 
	size mismatch for resnet.fc.weight: copying a param with shape torch.Size([20, 2048]) from checkpoint, the shape in current model is torch.Size([1000, 2048]).
	size mismatch for resnet.fc.bias: copying a param with shape torch.Size([20]) from checkpoint, the shape in current model is torch.Size([1000]).