In [None]:
# Extract the dataset archive; -q suppresses unzip's per-file listing.
# NOTE(review): no archive path follows the flags on this line — fill in the zip file to extract.
!unzip -q 

In [None]:
# add lib for project
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import glob
import PIL.Image
import os
import numpy as np

In [None]:
def get_x(path, width):
    """Return the x label parsed from an image filename, scaled around the image centre.

    The third underscore-separated field of ``path`` is read as an integer
    pixel column and mapped linearly so that 0 -> -1.0, ``width / 2`` -> 0.0
    and ``width`` -> 1.0.
    """
    pixel_x = float(int(path.split("_")[2]))
    half_width = width / 2
    return (pixel_x - half_width) / half_width

def get_y(path, height):
    """Return the y label parsed from an image filename, scaled around the image centre.

    The fourth underscore-separated field of ``path`` is read as an integer
    pixel row and mapped linearly so that 0 -> -1.0, ``height / 2`` -> 0.0
    and ``height`` -> 1.0.
    """
    pixel_y = float(int(path.split("_")[3]))
    half_height = height / 2
    return (pixel_y - half_height) / half_height

class XYDataset(torch.utils.data.Dataset):
    """Dataset of ``*.jpg`` images whose (x, y) target is encoded in the filename.

    Each item is an ``(image, target)`` pair:

    * ``image`` -- the picture after colour jitter, resize to 224x224,
      conversion to a tensor, reversal of the channel axis, and
      normalisation with the mean/std constants used in ``__getitem__``.
    * ``target`` -- ``torch.tensor([x, y])`` as float32, where x and y are
      produced by ``get_x`` / ``get_y`` from the file's base name and the
      image's own width and height.

    Args:
        directory: Folder searched (non-recursively) for ``*.jpg`` files.
        random_hflips: When true, each sample is mirrored horizontally with
            probability 0.5 and its x target is negated to match. When
            false, no flipping is performed.
    """

    def __init__(self, directory, random_hflips=False):
        self.directory = directory
        self.random_hflips = random_hflips
        # glob returns files in arbitrary order; sorting keeps index -> file stable.
        self.image_paths = sorted(glob.glob(os.path.join(self.directory, '*.jpg')))
        self.color_jitter = transforms.ColorJitter(0.3, 0.3, 0.3, 0.3)

    def __len__(self):
        """Number of image files found in ``directory``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, augment and normalise sample ``idx``; return ``(image, target)``."""
        image_path = self.image_paths[idx]

        # Force three channels so the 3-value normalisation below always
        # applies, and close the file as soon as the pixels are decoded.
        with PIL.Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        width, height = image.size
        filename = os.path.basename(image_path)
        x = float(get_x(filename, width))
        y = float(get_y(filename, height))

        # Only flip when the caller asked for it; a mirrored image needs a
        # mirrored x target. np.random.rand() yields a scalar directly.
        if self.random_hflips and np.random.rand() > 0.5:
            image = transforms.functional.hflip(image)
            x = -x

        image = self.color_jitter(image)
        image = transforms.functional.resize(image, (224, 224))
        image = transforms.functional.to_tensor(image)
        # Reverse the first (channel) axis; .copy() makes the negative-stride
        # view contiguous so torch.from_numpy accepts it.
        image = image.numpy()[::-1].copy()
        image = torch.from_numpy(image)
        image = transforms.functional.normalize(image, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        return image, torch.tensor([x, y]).float()
    
# Index every *.jpg in the Drive folder; horizontal-flip augmentation is not requested.
dataset = XYDataset(
    directory="/content/drive/Othercomputers/MyLaptop/FileTrain/dataset_xy_vao",
    random_hflips=False,
)

In [None]:
# Hold out a fixed fraction of the samples for evaluation.
test_percent = 0.1
num_test = int(test_percent * len(dataset))

# Seed the split so the same dataset indices are held out on every run;
# with an unseeded split, test losses from different runs are measured on
# different samples and cannot be compared.
split_generator = torch.Generator().manual_seed(0)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [len(dataset) - num_test, num_test],
    generator=split_generator,
)

In [None]:
# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=0
)

# Evaluation gains nothing from shuffling; a fixed order keeps the batch
# composition (and so the averaged per-batch loss) the same from epoch to epoch.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=False,
    num_workers=0
)

In [None]:
# ImageNet-pretrained ResNet-18. The `weights=` argument replaces the
# `pretrained=True` flag, which torchvision reports as deprecated since 0.13;
# IMAGENET1K_V1 selects the same checkpoint that flag used to load.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [None]:
# Replace the final fully connected layer with a 2-output layer (x, y).
# Reading in_features from the existing layer avoids hard-coding 512.
model.fc = torch.nn.Linear(model.fc.in_features, 2)

# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on a runtime without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [None]:
# Change the working directory to the Drive folder; the relative checkpoint path used by the training cell resolves here.
%cd /content/drive/Othercomputers/MyLaptop/FileTrain

/content/drive/Othercomputers/MyLaptop/FileTrain


In [None]:
# Show the current working directory to confirm the %cd above took effect.
%pwd

'/content/drive/Othercomputers/MyLaptop/FileTrain'

In [None]:
NUM_EPOCHS = 300
# Relative path: the checkpoint is written into the current working directory.
BEST_MODEL_PATH = 'best_steering_model_xy_turn3.pth'
# Lowest test loss seen so far; infinity guarantees the first finite result is saved.
best_loss = float('inf')

optimizer = optim.Adam(model.parameters())

for epoch in range(NUM_EPOCHS):

    # ---- training pass: one optimiser step per batch ----
    model.train()
    train_loss = 0.0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = F.mse_loss(outputs, labels)
        train_loss += float(loss)
        loss.backward()
        optimizer.step()
    # Reported value is the mean of the per-batch losses.
    train_loss /= len(train_loader)

    # ---- evaluation pass ----
    model.eval()
    test_loss = 0.0
    # No gradients are needed here; disabling autograd avoids building a
    # graph for every forward pass and so reduces memory use.
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = F.mse_loss(outputs, labels)
            test_loss += float(loss)
    test_loss /= len(test_loader)

    print('%d %f, %f' % (epoch, train_loss, test_loss))
    # Keep only the weights that achieved the lowest test loss so far.
    if test_loss < best_loss:
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        best_loss = test_loss

0 1.188221, 1.493979
1 0.578388, 4.103417
2 0.149963, 0.535593
3 0.123476, 0.624579
4 0.123910, 0.037485
5 0.100705, 0.131554
6 0.096433, 0.031359
7 0.089622, 0.032508
8 0.055632, 0.024283
9 0.051490, 0.034867
10 0.103441, 0.091456
11 0.091748, 0.029263
12 0.116843, 0.048632
13 0.088862, 0.029905
14 0.040337, 0.042980
15 0.034081, 0.052821
16 0.087581, 0.029361
17 0.095983, 0.103731
18 0.070063, 0.041758
19 0.097883, 0.041786
20 0.061775, 0.104934
21 0.050779, 0.015359
22 0.028250, 0.080561
23 0.025801, 0.014362
24 0.025466, 0.016282
25 0.030598, 0.011210
26 0.027569, 0.021106
27 0.023031, 0.020023
28 0.026820, 0.020110
29 0.034606, 0.055017
30 0.025304, 0.018915
31 0.014265, 0.019442
32 0.022365, 0.017111
33 0.019413, 0.016715
34 0.023649, 0.030654
35 0.031191, 0.015216
36 0.028504, 0.038701
37 0.043415, 0.052536
38 0.053044, 0.021080
39 0.035316, 0.185886
40 0.022085, 0.035994
41 0.020837, 0.019762
42 0.019697, 0.021183
43 0.020704, 0.069861
44 0.021952, 0.035737
45 0.019595, 0.03414