In [None]:
import gym, random, os, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as tud
from atari_wrappers import make_atari, wrap_deepmind,LazyFrames
from tqdm import tqdm

In [None]:
# Build the Pong environment and apply the wrapper stack in a single expression.
# scale=False: AEMemory.push divides by 255 itself, so frames are presumably
# expected to arrive unscaled — confirm against atari_wrappers.
# frame_stack=True: observations are then presumably LazyFrames, which is what
# AEMemory.push calls `_force()` on — confirm against atari_wrappers.
env = wrap_deepmind(
    make_atari('PongNoFrameskip-v4'),
    frame_stack=True,
    scale=False,
)

In [None]:
class AEMemory(object):
    """Fixed-capacity ring buffer of preprocessed frames for auto-encoder training.

    Once `memory_size` frames have been stored, each new frame overwrites the
    oldest one, so `size()` never exceeds `memory_size`.

    Args:
        memory_size: maximum number of frames to keep (expected to be >= 1).
    """
    def __init__(self, memory_size=100000):
        self.buffer = []                # stored frames, at most memory_size entries
        self.memory_size = memory_size
        self.next_idx = 0               # slot that the next overwrite will target

    def push(self, state):
        """Preprocess one observation and store it.

        Args:
            state: object exposing `_force()` that returns a 3-D array
                (presumably LazyFrames in (H, W, C) layout — confirm against
                the wrapper). The last axis is moved to the front, a leading
                axis of length 1 is added, and values are divided by 255.
        """
        state = state._force().transpose(2,0,1)[None]/255.
        # Strict '<': with '<=' the buffer grew to memory_size + 1 entries and
        # the extra slot at index memory_size was never overwritten.
        if len(self.buffer) < self.memory_size:
            self.buffer.append(state)
        else: # buffer is full: overwrite the oldest entry
            self.buffer[self.next_idx] = state
        self.next_idx = (self.next_idx + 1) % self.memory_size

    def size(self):
        """Return the number of frames currently stored."""
        return len(self.buffer)

## Auto-encoder

In [None]:
class AutoEncoder(nn.Module):
    """Convolutional auto-encoder with a small fully connected bottleneck.

    The encoder is three convolutions followed by two linear layers that
    compress to `hidden_dim` features; the decoder mirrors it with two linear
    layers and three transposed convolutions.

    The first encoder linear layer expects a flattened 64x7x7 feature map,
    which is what an 80x80 input produces (80 -> 19 -> 9 -> 7). The decoder
    expands 7 -> 9 -> 19 -> 80, so reconstructions are always 80x80.

    Args:
        in_channels: number of channels in the input (and reconstructed) image.
        hidden_dim: width of the bottleneck code.
    """
    def __init__(self, in_channels, hidden_dim):
        super().__init__()
        flat_features = 7 * 7 * 64  # size of the flattened conv feature map

        # Layers are created in encoder -> decoder order.
        conv_down = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),  # 80x80 -> 32 x 19x19
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2),           # -> 64 x 9x9
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),           # -> 64 x 7x7
            nn.ReLU(),
        ]
        self.encoder_conv = nn.Sequential(*conv_down)

        fc_down = [
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, hidden_dim),  # bottleneck code
            nn.ReLU(),
        ]
        self.encoder_linear = nn.Sequential(*fc_down)

        fc_up = [
            nn.Linear(hidden_dim, 512),
            nn.ReLU(),
            nn.Linear(512, flat_features),  # flat; reshaped to 64 x 7x7 in forward
            nn.ReLU(),
        ]
        self.decoder_linear = nn.Sequential(*fc_up)

        conv_up = [
            nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1),          # 7x7 -> 9x9
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2),          # -> 19x19
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, kernel_size=8, stride=4)  # -> 80x80, no activation
        ]
        self.decoder_conv = nn.Sequential(*conv_up)

    def forward(self,x):
        """Encode `x` to the bottleneck code and return its reconstruction.

        Args:
            x: tensor of shape (batch, in_channels, H, W) whose encoder feature
                map is 64x7x7.

        Returns:
            Tensor of shape (batch, in_channels, 80, 80).
        """
        batch_size = x.size(0)

        # encoder: conv features -> flat vector -> bottleneck code
        features = self.encoder_conv(x)
        code = self.encoder_linear(features.reshape(batch_size, -1))

        # decoder: code -> flat vector -> 64x7x7 map -> image
        expanded = self.decoder_linear(code)
        feature_map = expanded.reshape(expanded.size(0), 64, 7, 7)
        return self.decoder_conv(feature_map)

In [None]:
## fill the memory
memory = AEMemory(100000)

# Roll out 20 episodes with a uniformly random policy and store the frames.
# Throwaway values get explicit names instead of `_`: at notebook top level,
# assigning `_` shadows IPython's "last output" variable.
for _episode in tqdm(range(20)):
    frame = env.reset()
    done = False
    while not done:
        action = random.randrange(env.action_space.n)
        next_frame, _reward, done, _info = env.step(action)
        # Store the observation the action was taken from; the terminal
        # observation of each episode is therefore never pushed.
        memory.push(frame)
        frame = next_frame

In [None]:
## AutoEncoder Dataset
class AEDataset(tud.Dataset):
    """In-memory dataset over the frames collected in a memory object.

    Args:
        memory: object with a `buffer` attribute holding equally shaped
            arrays whose axis 0 has length 1 (as stored by AEMemory.push).
    """
    def __init__(self,memory):
        # Stack the list into one float32 ndarray first: torch.Tensor(list of
        # ndarrays) converts element by element and is very slow on large
        # buffers. Shape, dtype (float32) and values are unchanged.
        frames = np.asarray(memory.buffer, dtype=np.float32)
        # Drop the per-frame singleton axis: (N, 1, ...) -> (N, ...).
        self.states = torch.from_numpy(frames).squeeze(1)

    def __len__(self):
        """Return the number of stored frames."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return the frame tensor at position `idx`."""
        return self.states[idx]

In [None]:
# Wrap the collected frames in a dataset and serve them in shuffled
# mini-batches of 32 for auto-encoder training.
aedataset = AEDataset(memory=memory)
dataloader = tud.DataLoader(dataset=aedataset, shuffle=True, batch_size=32)

In [None]:
# Train the auto-encoder to reconstruct its own input frames.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoEncoder(in_channels = 4, hidden_dim = 6).to(device)
loss_fn = nn.SmoothL1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr = 5e-4)

for e in range(20):
    losses = []
    for batch in dataloader:
        batch = batch.to(device)
        output = model(batch)
        # Loss signature is (input, target): the prediction goes first and the
        # original frames are the target. The two must have the same shape.
        # NOTE(review): confirm the wrapper's frame size matches the
        # reconstruction size produced by the decoder.
        loss = loss_fn(output, batch)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        losses.append(loss.item())
    
    # Report the mean batch loss for this epoch.
    print("epoch:", e, "loss:", round(np.mean(losses),4))

In [None]:
# Persist the two trained encoder halves as whole pickled modules (not
# state_dicts); loading them with torch.load requires the same torch.nn
# classes to be importable.
# Create the target directory first: torch.save does not create missing
# parent directories and would otherwise raise FileNotFoundError.
os.makedirs('saved_model', exist_ok=True)
torch.save(model.encoder_conv,'saved_model/daqn_pre_conv')
torch.save(model.encoder_linear,'saved_model/daqn_pre_linear')