In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import os
import pandas as pd
import os
from tqdm.notebook import tqdm
from PIL import Image

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 8
IMG_WIDTH = 224
IMG_HEIGHT = 224

In [3]:
target_files = ["1652875851.3497071", "1652875901.3107166", "1652876013.741493", "1652876206.2541456", "1652876485.8123376", "1652959186.4507334",
                "1652959347.972946", "1653042695.4914637", "1653042775.5213027", "1653043202.5073502", "1653043345.3415065", "1653043428.8546412", "1653043549.5187616"]

In [4]:
random_transforms = transforms.Compose([
    transforms.RandomRotation(5),
    transforms.RandomResizedCrop((IMG_WIDTH, IMG_HEIGHT), scale=(.9, 1), ratio=(1, 1)),
    transforms.ColorJitter(brightness=.2, contrast=.5,saturation=0.5,hue=0.5)
])

In [5]:
class ImageDataset(Dataset):
    def __init__(self, files):
        self.image_dirs = files
        self.targets = pd.DataFrame(columns=['path', 'forward', 'left'])
        for image_dir in self.image_dirs:
            def change_name(number:float) -> str:
                true_num = int(number)
                return f"{os.path.join(image_dir, f'{true_num:04}.jpg')}"
            frame = pd.read_csv(os.path.join('dataset', f'{image_dir}.csv'), names=['path', 'forward', 'left'])
            frame['path'] = frame['path'].apply(change_name)
            self.targets = pd.concat([self.targets, frame])
            self.paths:pd.Series[str] = self.targets.pop('path')
            break

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        image_path = self.paths[idx]
        target = torch.tensor(self.targets.iloc[idx].values,dtype=torch.float32)
        image = transforms.ToTensor()(Image.open(os.path.join('dataset', image_path)))
        image = image.to(device)
        target = target.to(device)
        return image, target


In [6]:
img_dataset = ImageDataset(target_files)

  self.targets = pd.concat([self.targets, frame])


In [7]:
train_dataset, val_dataset = random_split(img_dataset, [0.8, 0.2])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

In [8]:
class DrivingModel(nn.Sequential):
    def __init__(self):
        super(DrivingModel, self).__init__()
        self.input_shape = (1, 3, IMG_WIDTH, IMG_HEIGHT)
        
        self.append(nn.Conv2d(3, 16, kernel_size=3, padding=1))
        #self.append(nn.BatchNorm2d(16))
        self.append(nn.ReLU())
        #self.append(nn.Dropout2d(.3))
        self.append(nn.MaxPool2d(kernel_size=2, stride=2))
        
        self.append(nn.Conv2d(16, 32, kernel_size=3, padding=1))
        #self.append(nn.BatchNorm2d(32))
        self.append(nn.ReLU())
        #self.append(nn.Dropout2d(.3))
        self.append(nn.MaxPool2d(kernel_size=2, stride=2))
        
        self.append(nn.Conv2d(32, 64, kernel_size=3, padding=1))
        #self.append(nn.BatchNorm2d(64))
        self.append(nn.ReLU())
        #self.append(nn.Dropout2d(.3))
        self.append(nn.MaxPool2d(kernel_size=2, stride=2))
        
        self.append(nn.Flatten())
        
        self.append(nn.Linear(64 * 28 * 28, 64))
        #self.append(nn.Dropout1d(.3))
        self.append(nn.ReLU())
        self.append(nn.Linear(64,2))

    def forward(self, x):
        x = x.view(-1, 3, 224, 224)
        for layer in self:
            x = layer(x)
        return x

In [9]:
model= DrivingModel().to(device)
model = torch.compile(model)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

In [10]:
num_epochs = 200
best_val_loss = None
counter = 10
for epoch in range(num_epochs):
    train_loss = 0
    model.train()
    for images, targets in tqdm(train_loader, desc="Training"):
        optimizer.zero_grad()
        images = random_transforms(images)
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    print(f"Epoch {epoch+1}/{num_epochs}, Training Loss: {train_loss / len(train_loader)}")
    
    val_loss = 0
    model.eval()
    for images, targets in tqdm(val_loader, desc="Validation"):
        images = random_transforms(images)
        images = images.to(device)
        targets = targets.to(device)
        with torch.inference_mode():
            outputs = model(images)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
    if best_val_loss is None:
        best_val_loss = val_loss
    elif best_val_loss < val_loss:
        counter -=1
    else:
        best_val_loss = val_loss
        counter = 10
    if not counter:
        break
    print(f"Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss / len(val_loader)}")
torch.save(model, 'model.pth')

Training:   0%|          | 0/19 [00:00<?, ?it/s]

W0518 12:10:30.418000 16848 torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode


Epoch 1/200, Training Loss: 0.3821396971807668


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1/200, Validation Loss: 0.1484855130314827


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 2/200, Training Loss: 0.12424746278281275


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 2/200, Validation Loss: 0.1371241280809045


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 3/200, Training Loss: 0.12266201324956982


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 3/200, Validation Loss: 0.13245464004576207


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 4/200, Training Loss: 0.11787191896062148


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 4/200, Validation Loss: 0.12189478976652027


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 5/200, Training Loss: 0.12595113303120198


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 5/200, Validation Loss: 0.12579903416335583


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 6/200, Training Loss: 0.12766715590106814


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 6/200, Validation Loss: 0.10655075088143348


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 7/200, Training Loss: 0.1155661940574646


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 7/200, Validation Loss: 0.12685466296970843


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 8/200, Training Loss: 0.11021761998142067


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 8/200, Validation Loss: 0.10139838494360447


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 9/200, Training Loss: 0.1034898171691518


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 9/200, Validation Loss: 0.10425234586000443


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 10/200, Training Loss: 0.10778413296334054


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 10/200, Validation Loss: 0.09990274757146836


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 11/200, Training Loss: 0.10748783136276822


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 11/200, Validation Loss: 0.09399459846317768


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 12/200, Training Loss: 0.1055095961415454


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 12/200, Validation Loss: 0.10576561912894249


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 13/200, Training Loss: 0.10617508780897449


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 13/200, Validation Loss: 0.09580107647925615


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 14/200, Training Loss: 0.11245216620399763


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 14/200, Validation Loss: 0.09993486888706685


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 15/200, Training Loss: 0.09453927191268456


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 15/200, Validation Loss: 0.0845951821655035


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 16/200, Training Loss: 0.09371822175422781


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 16/200, Validation Loss: 0.10383734423667193


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 17/200, Training Loss: 0.09755186775797292


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 17/200, Validation Loss: 0.09968255124986172


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 18/200, Training Loss: 0.09297017215162907


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 18/200, Validation Loss: 0.09250237978994846


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 19/200, Training Loss: 0.10224556638614128


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 19/200, Validation Loss: 0.08024129644036293


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 20/200, Training Loss: 0.09749987380775182


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 20/200, Validation Loss: 0.09058868708088993


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 21/200, Training Loss: 0.11618767796378386


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 21/200, Validation Loss: 0.08424396216869354


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 22/200, Training Loss: 0.09298360117368008


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 22/200, Validation Loss: 0.08186916653066874


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 23/200, Training Loss: 0.08999599273757715


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 23/200, Validation Loss: 0.08009761944413185


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 24/200, Training Loss: 0.08816062217872393


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 24/200, Validation Loss: 0.09224732238799334


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 25/200, Training Loss: 0.08439578900211736


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 25/200, Validation Loss: 0.09056157059967518


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 26/200, Training Loss: 0.07937842342806489


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 26/200, Validation Loss: 0.10193339660763741


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 27/200, Training Loss: 0.09899449110717366


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 27/200, Validation Loss: 0.08568283170461655


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 28/200, Training Loss: 0.09256305763694017


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 28/200, Validation Loss: 0.10369568802416325


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 29/200, Training Loss: 0.0973001847925939


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 29/200, Validation Loss: 0.08137013465166092


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 30/200, Training Loss: 0.08991248709590811


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 30/200, Validation Loss: 0.08862639591097832


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 31/200, Training Loss: 0.09513672019698118


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 31/200, Validation Loss: 0.08530332632362843


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 32/200, Training Loss: 0.08011143968293541


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 32/200, Validation Loss: 0.06815932504832745


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 33/200, Training Loss: 0.08033738471567631


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 33/200, Validation Loss: 0.09276652783155441


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 34/200, Training Loss: 0.08645723804243301


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 34/200, Validation Loss: 0.0899450683966279


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 35/200, Training Loss: 0.08384574498785169


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 35/200, Validation Loss: 0.07406484819948674


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 36/200, Training Loss: 0.08833746560604165


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 36/200, Validation Loss: 0.09695865362882614


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 37/200, Training Loss: 0.0930627707980181


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 37/200, Validation Loss: 0.09634248353540897


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 38/200, Training Loss: 0.09981868220003028


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 38/200, Validation Loss: 0.08567083440721035


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 39/200, Training Loss: 0.08898926778745495


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 39/200, Validation Loss: 0.08347911573946476


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 40/200, Training Loss: 0.10380937529139612


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 40/200, Validation Loss: 0.07391533330082893


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 41/200, Training Loss: 0.09395777965944849


Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 41/200, Validation Loss: 0.096139794588089


Training:   0%|          | 0/19 [00:00<?, ?it/s]

Epoch 42/200, Training Loss: 0.08916492946445942


Validation:   0%|          | 0/5 [00:00<?, ?it/s]