In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torchmetrics.classification import Accuracy
import torch.nn.functional as F
import gymnasium as gym
%load_ext autoreload
%autoreload 2
# Pick the first available accelerator: CUDA, then Apple MPS, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [76]:
class AngleNet(nn.Module):
    """Classify a 16x16 single-channel crop into one of 16 angle classes.

    Architecture: an unpadded 3x3 conv (1 -> 4 channels), a 2x2 max-pool, and a
    linear layer mapping the 4 x 7 x 7 pooled features to 16 logits.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, 3)
        self.pool = nn.MaxPool2d(2, 2)
        # 16x16 input -> 14x14 after the unpadded 3x3 conv -> 7x7 after pooling.
        self.fc1 = nn.Linear(7*7*4, 16)

    def forward(self, x):
        """Return angle logits of shape (batch, 16) on the CPU.

        Args:
            x: crops of shape (batch, 1, 16, 16), e.g. the output of ``prep_input``.
        """
        # ``Tensor.to`` is not in-place: the result must be reassigned for the
        # input to actually move to the model's device.
        x = x.to(device)
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        # Logits are handed back on the CPU, the same convention as XNet/YNet.
        x = x.to("cpu")
        return x

    def prep_input(self, data, x_coords, y_coords):
        """Cut a 16x16 window out of each frame after undoing its shifts.

        ``CerberusDataset`` rolls every frame by (105, 75) and then by the
        sample's (y_shift, x_shift). Rolling by the negated total, using the
        predicted coordinates, restores the un-shifted frame when the
        predictions are correct. Rows 97:113 and columns 80:96 are then cropped
        (presumably the region around the ship's starting position -- TODO confirm).

        Args:
            data: frames of shape (batch, 1, H, W); only channel 0 is used.
            x_coords: per-sample predicted x shift (integer tensor, length batch).
            y_coords: per-sample predicted y shift (integer tensor, length batch).

        Returns:
            Tensor of shape (batch, 1, 16, 16) on the same device as ``data``.
        """
        angle_data = torch.zeros([data.shape[0], 1, 16, 16], device=data.device)
        for idx, item in enumerate(data):
            shifts = (-105 - int(y_coords[idx]), -75 - int(x_coords[idx]))
            angle_data[idx, 0] = item[0].roll(shifts, (0, 1))[97:113, 80:96]
        return angle_data

In [4]:
class XNet(nn.Module):
    """Predict a frame's horizontal position as one of 160 column classes.

    A padded 5x5 conv keeps the spatial size, the activations are max-reduced
    over the row axis (dim 2), and a linear layer maps the remaining
    96 channels x 160 columns to 160 logits. The linear layer's input size
    means single-channel frames must be 160 columns wide.
    """

    def __init__(self):
        super().__init__()
        n_channels, width = 96, 160
        self.conv1 = nn.Conv2d(1, n_channels, 5, padding=2)
        self.fc1 = nn.Linear(width * n_channels, width)

    def forward(self, x):
        """Return column logits of shape (batch, 160) on the CPU."""
        features = F.relu(self.conv1(x.to(device)))
        # Collapse the row axis so only per-column evidence remains.
        column_features = features.amax(dim=2)
        logits = self.fc1(column_features.flatten(start_dim=1))
        return logits.to("cpu")

In [5]:
class YNet(nn.Module):
    """Predict a frame's vertical position as one of 210 row classes.

    A padded 5x5 conv keeps the spatial size, the activations are max-reduced
    over the column axis (dim 3), and a linear layer maps the remaining
    96 channels x 210 rows to 210 logits. The linear layer's input size means
    single-channel frames must be 210 rows tall.
    """

    def __init__(self):
        super().__init__()
        n_channels, height = 96, 210
        self.conv1 = nn.Conv2d(1, n_channels, 5, padding=2)
        self.fc1 = nn.Linear(height * n_channels, height)

    def forward(self, x):
        """Return row logits of shape (batch, 210) on the CPU."""
        features = F.relu(self.conv1(x.to(device)))
        # Collapse the column axis so only per-row evidence remains.
        row_features = features.amax(dim=3)
        logits = self.fc1(row_features.flatten(start_dim=1))
        return logits.to("cpu")

In [65]:
class CerberusDataset(Dataset):
    """Synthetic Asteroids frames labelled with their shift and orientation.

    Construction records 16 grayscale frames from ``AsteroidsNoFrameskip-v4``,
    keeping the observation after every fourth ``step(4)`` call. Presumably
    action 4 rotates the ship, so these are 16 orientations (TODO confirm
    against the env's action meanings). Each frame is scaled by 1/256 and
    rolled by (105, 75). Every sample then picks one of the 16 frames at
    random and rolls it by a random (y_shift, x_shift).

    Items are ``(frame, x_onehot, y_onehot, angle_onehot)``:

    - ``frame`` has shape (1, 210, 160) and lives on ``device``.
    - The one-hot labels have 160, 210 and 16 classes and are float32 CPU tensors.

    ``self.env`` is kept as an attribute, but it is closed once the frames
    have been captured.

    Args:
        num_samples: number of samples to generate.
        seed: optional seed for the environment reset and for the random
            shift/angle draws. ``None`` (default) uses an unseeded reset and
            torch's global RNG.
    """

    def __init__(self, num_samples, seed=None):
        self.num_samples = num_samples
        generator = None if seed is None else torch.Generator().manual_seed(seed)
        self.env = gym.make('AsteroidsNoFrameskip-v4', obs_type="grayscale")
        try:
            obs, info = self.env.reset(seed=seed)
            angle_states = torch.zeros([16, 210, 160])
            for i in range(16):
                # Four steps per recorded frame; only the last observation is kept.
                for _ in range(3):
                    self.env.step(4)
                obs, reward, terminated, truncated, info = self.env.step(4)
                angle_states[i] = torch.roll(torch.from_numpy(obs)/256, [105, 75], dims=[0, 1])
        finally:
            # The env is only needed to capture the 16 frames. Closing it stops
            # a new emulator leaking each time a dataset is built (e.g. every epoch).
            self.env.close()
        x_shifts = torch.randint(0, 160, [num_samples], generator=generator)
        y_shifts = torch.randint(0, 210, [num_samples], generator=generator)
        angles = torch.randint(0, 16, [num_samples], generator=generator)
        data = torch.zeros([num_samples, 1, 210, 160])
        for i in range(num_samples):
            data[i, 0] = torch.roll(angle_states[angles[i]], [int(y_shifts[i]), int(x_shifts[i])], [0, 1])
        self.data = data.to(device)
        self.x_shifts = F.one_hot(x_shifts, num_classes=160).to(torch.float32)
        self.y_shifts = F.one_hot(y_shifts, num_classes=210).to(torch.float32)
        self.angles = F.one_hot(angles, num_classes=16).to(torch.float32)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx], self.x_shifts[idx], self.y_shifts[idx], self.angles[idx]

In [96]:
# Shared loss, a fixed held-out dataset, and the three networks on `device`.
# The networks are built left to right: X, then Y, then angle.
loss_fn = nn.CrossEntropyLoss()
test_dataset = CerberusDataset(num_samples=1000)
x_net, y_net, angle_net = (
    XNet().to(device),
    YNet().to(device),
    AngleNet().to(device),
)

In [97]:
# Hyperparameters
x_lr = 1e-3
y_lr = 1e-3
angle_lr = 1e-3
epochs = 8
angle_start = 0  # AngleNet trains only in epochs strictly greater than this
batch_size = 32
num_samples = 1024
x_optim = Adam(x_net.parameters(), lr=x_lr)
y_optim = Adam(y_net.parameters(), lr=y_lr)
angle_optim = Adam(angle_net.parameters(), lr=angle_lr)
for epoch in range(epochs):
    # A fresh synthetic dataset is generated every epoch.
    dataloader = DataLoader(CerberusDataset(num_samples), batch_size=batch_size)
    for data, x_shifts, y_shifts, angles in dataloader:
        # Targets are one-hot float tensors, which CrossEntropyLoss treats as
        # class probabilities.
        x_pred = x_net(data)
        x_optim.zero_grad()
        x_loss = loss_fn(x_pred, x_shifts)
        x_loss.backward()
        x_optim.step()

        y_pred = y_net(data)
        y_optim.zero_grad()
        y_loss = loss_fn(y_pred, y_shifts)
        y_loss.backward()
        y_optim.step()

        if epoch > angle_start:  # Only train angle network once x and y are relatively good
            # Locate the ship with the just-updated X/Y nets. The argmax results
            # are only used as crop offsets, so no autograd graph is needed.
            with torch.no_grad():
                x_coords = torch.argmax(x_net(data), dim=1)
                y_coords = torch.argmax(y_net(data), dim=1)
                angle_data = angle_net.prep_input(data, x_coords, y_coords)
            angle_pred = angle_net(angle_data)
            angle_optim.zero_grad()
            angle_loss = loss_fn(angle_pred, angles)
            angle_loss.backward()
            angle_optim.step()
# Test models: compare argmax predictions with the argmax of the one-hot targets.
# The nets return CPU logits and the dataset's labels are CPU tensors, so the
# metrics are kept on the CPU too (no mixed-device updates).
x_accuracy = Accuracy(task="multiclass", num_classes=160)
y_accuracy = Accuracy(task="multiclass", num_classes=210)
angle_accuracy = Accuracy(task="multiclass", num_classes=16)

test_dataloader = DataLoader(test_dataset, batch_size=32)
with torch.no_grad():
    for data, x_shifts, y_shifts, angles in test_dataloader:
        # One forward pass per net. The same X/Y predictions feed their own
        # metrics and the AngleNet crop, as in training.
        x_pred = torch.argmax(x_net(data), dim=1)
        x_accuracy.update(x_pred, torch.argmax(x_shifts, dim=1))

        y_pred = torch.argmax(y_net(data), dim=1)
        y_accuracy.update(y_pred, torch.argmax(y_shifts, dim=1))

        angle_data = angle_net.prep_input(data, x_pred, y_pred)
        angle_pred = torch.argmax(angle_net(angle_data), dim=1)
        angle_accuracy.update(angle_pred, torch.argmax(angles, dim=1))
print(f'X:{x_accuracy.compute()}, Y:{y_accuracy.compute()}, Angle:{angle_accuracy.compute()}')

X:1.0, Y:1.0, Angle:1.0
