In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import convLSTM as convLSTM
from pathlib import Path
from time import time
from PIL import Image
from RNN_classes_funcs_Marchese import *

  except ModuleNotFoundError: warn("Missing `graphviz` - please run `conda install fastbook`")


In [2]:
device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=3)

In [3]:
# Get classes and filenames
path = Path("data_RNN")
classes = get_classes(path)
all_filenames = get_filenames(path)
all_filenames.sort()

In [4]:
# Getting size of dataset and corresponding list of indices
dataset_size = len(all_filenames)
dataset_indices = list(range(dataset_size))

In [5]:
# Getting index for where we want to split the data
val_split_index = int(np.floor(0.2 * dataset_size))

In [6]:
# Splitting list of indices into training and validation indices
train_idx, val_idx = dataset_indices[val_split_index:], dataset_indices[:val_split_index]

In [7]:
# Getting list of filenames for training and validation set
train_filenames = [all_filenames[i] for i in train_idx]
val_filenames = [all_filenames[i] for i in val_idx]
train_filenames

[Path('data_RNN/01390-straight-straight.png'),
 Path('data_RNN/01391-straight-straight.png'),
 Path('data_RNN/01392-straight-straight.png'),
 Path('data_RNN/01393-straight-straight.png'),
 Path('data_RNN/01394-straight-straight.png'),
 Path('data_RNN/01395-straight-straight.png'),
 Path('data_RNN/01396-straight-straight.png'),
 Path('data_RNN/01397-straight-straight.png'),
 Path('data_RNN/01398-straight-straight.png'),
 Path('data_RNN/01399-straight-straight.png'),
 Path('data_RNN/01400-straight-straight.png'),
 Path('data_RNN/01401-straight-straight.png'),
 Path('data_RNN/01402-straight-straight.png'),
 Path('data_RNN/01403-straight-straight.png'),
 Path('data_RNN/01404-straight-straight.png'),
 Path('data_RNN/01405-straight-straight.png'),
 Path('data_RNN/01406-straight-straight.png'),
 Path('data_RNN/01407-straight-straight.png'),
 Path('data_RNN/01408-straight-straight.png'),
 Path('data_RNN/01409-straight-straight.png'),
 Path('data_RNN/01410-straight-straight.png'),
 Path('data_R

In [8]:
# Getting data via custom dataset
train_data = ImageDataset(classes, train_filenames)
val_data = ImageDataset(classes, val_filenames)

In [9]:
# Loading in data
train_loader = DataLoader(dataset=train_data, shuffle=False, batch_size=8)
val_loader = DataLoader(dataset=val_data, shuffle=False, batch_size=8)
img, label = next(iter(train_loader))

In [10]:
img.size()

torch.Size([8, 1, 3, 240, 320])

In [11]:
net = ConvRNN()
net.to(device)

ConvRNN(
  (convlstm): ConvLSTM(
    (cell_list): ModuleList(
      (0): ConvLSTMCell(
        (conv): Conv2d(13, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
    )
  )
  (flat): Flatten(start_dim=1, end_dim=-1)
  (lin1): Linear(in_features=768000, out_features=512, bias=True)
  (relu): ReLU()
  (lin2): Linear(in_features=512, out_features=3, bias=True)
)

In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)

In [13]:
num_epochs = 20

In [14]:
net.train()

for epoch in range(num_epochs):  # loop over the dataset multiple times

    running_loss = 0.0
    
    start = time()
    
    for data in train_loader:
        # Get the inputs and labels; currently ommitting cmd
        img, label = data
        
        # Putting data into the GPU
        img = img.to(device)
        label = label.to(device)
        
        # zero the parameter gradients
        optimizer.zero_grad()
        
        # forward + backward + optimize
        output = net(img)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        
    print(f"Epoch:{epoch+1:}/{num_epochs}, Training Loss:{running_loss:0.1f}, Time:{time()-start:0.1f}s")

print('Finished Training')

Epoch:1/20, Training Loss:546.6, Time:63.6s
Epoch:2/20, Training Loss:350.6, Time:63.2s
Epoch:3/20, Training Loss:339.7, Time:63.0s
Epoch:4/20, Training Loss:288.5, Time:63.8s
Epoch:5/20, Training Loss:241.2, Time:63.3s
Epoch:6/20, Training Loss:217.0, Time:63.5s
Epoch:7/20, Training Loss:189.3, Time:64.3s
Epoch:8/20, Training Loss:143.5, Time:64.6s
Epoch:9/20, Training Loss:113.3, Time:64.6s
Epoch:10/20, Training Loss:88.9, Time:64.2s
Epoch:11/20, Training Loss:60.5, Time:64.5s
Epoch:12/20, Training Loss:39.5, Time:64.5s
Epoch:13/20, Training Loss:43.1, Time:63.9s
Epoch:14/20, Training Loss:30.0, Time:64.8s
Epoch:15/20, Training Loss:12.1, Time:63.3s
Epoch:16/20, Training Loss:8.1, Time:62.6s
Epoch:17/20, Training Loss:7.1, Time:64.5s
Epoch:18/20, Training Loss:7.2, Time:64.8s
Epoch:19/20, Training Loss:10.1, Time:63.9s
Epoch:20/20, Training Loss:29.9, Time:64.4s
Finished Training


In [15]:
# Checking accuracy on validation set

correct = 0
total = 0

# Variables to keep track of accuracy for each class
class_correct = [0 for _ in classes]
class_total = [0 for _ in classes]

net.eval()

with torch.no_grad():

    for data in val_loader:

        # Get the inputs and labels; currently ommitting cmd
        img, label = data
        
        # Putting data into the GPU
        img = img.to(device)
        label = label.to(device)


        # Predict
        output = net(img)
        
        # Assuming we always get batches
        for i in range(output.size()[0]):
                
            # Getting the predicted most probable move
            move = torch.argmax(output[i])
                
            if move == label[i]:
                class_correct[label[i]] += 1
                class_total[label[i]] += 1
                correct +=1
            else:
                class_total[label[i]] += 1
            total += 1
        
# Calculate and output total set accuracy 
accuracy = correct / total
print(f"Accuracy on validation set: {correct}/{total} = {accuracy*100:.2f}%")

# Calculate and show accuracy for each class
for i, cls in enumerate(classes):
    ccorrect = class_correct[i]
    ctotal = class_total[i]
    caccuracy = ccorrect / ctotal
    print(f"  Accuracy on {cls:>5} class: {ccorrect}/{ctotal} = {caccuracy*100:.2f}%")

Accuracy on validation set: 1111/1389 = 79.99%
  Accuracy on  left class: 112/143 = 78.32%
  Accuracy on right class: 233/283 = 82.33%
  Accuracy on straight class: 766/963 = 79.54%


In [18]:
PATH = 'torch_RNN.pth'
torch.save(net.state_dict(), PATH)