In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.nn.functional as F  # Ensure this import is added

from torch.utils.data import DataLoader, TensorDataset
import gym
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import cv2

In [None]:

# Parameters
sequence_length = 4  # Number of images in each sequence
num_episodes = 1000   # Number of episodes for data collection
SEED = 0              # Arbitrary fixed seed so a fresh "Restart & Run All" is repeatable

# Seed every source of randomness used by this notebook: PyTorch (weight
# initialisation, DataLoader shuffling) and NumPy.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Environment Setup
env = gym.make('CartPole-v1')
# The notebook relies on the pre-0.26 Gym API (4-tuple `step`, `render(mode=...)`),
# where the environment and its action space are seeded separately.
env.seed(SEED)
env.action_space.seed(SEED)
data_images = []  # One stacked frame-sequence tensor per recorded step
data_states = []  # Observation recorded alongside each frame sequence

# Transformation for images
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # Resize image to manageable size
    transforms.ToTensor()         # Convert image to PyTorch tensor
])

def heuristic_policy(observation):
    """Choose an action from the sign of the pole angle.

    Args:
        observation: Sequence of exactly four values; only the third one
            (the pole angle) is used.

    Returns:
        0 when the angle is strictly negative, 1 otherwise (zero included).
    """
    _, _, pole_angle, _ = observation
    if pole_angle < 0:
        return 0
    return 1

# Data collection with uniformly random actions.
# NOTE: `heuristic_policy` is defined above but is not used here.
for episode in range(num_episodes):
    observation = env.reset()
    frames = []  # Frames rendered so far in this episode

    for t in range(1000):
        # Render the current state and convert it to a tensor immediately.
        frame = transform(Image.fromarray(env.render(mode='rgb_array')))
        frames.append(frame)

        # Take the most recent `sequence_length` frames. Early in an episode
        # there are not enough yet, so left-pad with all-zero frames of the
        # same shape (derived from the frame, not hard-coded).
        window = frames[-sequence_length:]
        missing = sequence_length - len(window)
        if missing > 0:
            window = [torch.zeros(frame.shape)] * missing + window

        # Stack to (T, C, H, W), then move channels first -> (C, T, H, W).
        sequence_tensor = torch.stack(window, dim=0).permute(1, 0, 2, 3)
        data_images.append(sequence_tensor)
        # `observation` is the state from which the newest frame was rendered.
        data_states.append(observation)

        action = env.action_space.sample()
        observation, reward, done, info = env.step(action)
        if done:
            break

env.close()

# Convert data_states to a tensor. Going through a single ndarray avoids the
# very slow element-by-element path PyTorch takes for a list of ndarrays.
data_states = torch.tensor(np.array(data_states), dtype=torch.float32)

# Dataset and DataLoader
dataset = TensorDataset(torch.stack(data_images), data_states)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# Model Definition
class CNN(nn.Module):
    """3D-convolutional regressor from a short frame sequence to a state vector.

    Expects input shaped (N, 3, sequence_length, H, W) and returns
    (N, num_outputs).

    Args:
        sequence_length: Number of frames along the temporal axis.
        image_size: (height, width) of every frame.
        num_outputs: Number of values predicted per sample.

    The defaults reproduce the original fixed architecture
    (4 frames of 64x64, 4 outputs).
    """

    def __init__(self, sequence_length=4, image_size=(64, 64), num_outputs=4):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv3d(3, 16, kernel_size=(3, 3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),
            nn.Conv3d(16, 32, kernel_size=(3, 3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),
        )
        # The padded 3x3x3 convolutions keep every dimension; each pooling
        # layer halves H and W (rounding down) and leaves the temporal axis
        # untouched. The last convolution has 32 output channels.
        height, width = image_size
        flat_features = 32 * sequence_length * (height // 2 // 2) * (width // 2 // 2)
        self.fc_layers = nn.Sequential(
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Linear(128, num_outputs),  # One output per predicted state variable
        )

    def forward(self, x):
        """Run the convolutional stack, flatten per sample, then regress."""
        x = self.conv_layers(x)
        x = torch.flatten(x, start_dim=1)  # Keep the batch axis, flatten the rest
        return self.fc_layers(x)

# Model instantiation and training setup
model = CNN()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training Loop
num_epochs = 15  # More epochs might give a better result (untested)
model.train()  # Make training mode explicit; later cells switch the model to eval()
for epoch in range(num_epochs):
    running_loss = 0.0  # Sum of per-batch mean losses, weighted by batch size
    num_samples = 0
    for images, states in dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, states)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size

    # Report the mean loss over the whole epoch rather than the (noisy) loss
    # of whichever mini-batch happened to come last. The max() guard keeps an
    # empty dataloader from raising.
    epoch_loss = running_loss / max(num_samples, 1)
    print(f'Epoch {epoch+1}, Loss: {epoch_loss}')

# Save the model
# torch.save(model.state_dict(), 'cartpole_cnn.pth')


In [None]:
# Quick sanity check: compare one prediction with its ground-truth state

model.eval()
with torch.no_grad():  # Inference only: no autograd graph needs to be built
    for images, states in dataloader:
        print(images.shape,states.shape)
        # Predict on the first sample of the batch (re-adding the batch axis).
        print(model(images[0].unsqueeze(0)),states[0])
        break  # One batch is enough for a quick look

In [None]:
# A closer look at how well the model does

# 1: Collect fresh CartPole frames (policy: random actions)

env = gym.make('CartPole-v1')
data_images_bis = []
data_states_bis = []

for episode in range(3):
    # Track the state in `observation_bis` throughout, so this cell never
    # reads an `observation` left over from an earlier cell.
    observation_bis = env.reset()
    images_bis = []
    for t in range(1000):
        img = env.render(mode='rgb_array')
        img_pil = Image.fromarray(img)
        tensor_image = transform(img_pil)  # Transform image immediately
        images_bis.append(tensor_image)

        # Unlike the training data, no zero padding is used here: samples are
        # only recorded once `sequence_length` real frames are available.
        if len(images_bis) >= sequence_length:
            # Stack the last sequence_length images to form a single sequence tensor
            sequence_tensor = torch.stack(images_bis[-sequence_length:], dim=0).permute(1, 0, 2, 3)
            data_images_bis.append(sequence_tensor)
            data_states_bis.append(observation_bis)

        action = env.action_space.sample()  # Random action
        observation_bis, reward, done, info = env.step(action)
        if done:
            break

env.close()  # Release the environment's rendering resources

# Going through a single ndarray avoids the very slow path PyTorch takes for
# a Python list of ndarrays.
data_states_bis = torch.tensor(np.array(data_states_bis), dtype=torch.float32)

# Stack all sequences into one tensor
data_images_bis = torch.stack(data_images_bis)

In [None]:
# 2: Show what the model predicts versus the true observations

model.eval()
total = 0.0  # Sum of squared errors over all evaluation samples
with torch.no_grad():
    for images, states in zip(data_images_bis, data_states_bis):
        # Run the model once per sample and reuse the result for both the
        # printout and the error accumulation.
        prediction = model(images.unsqueeze(0))
        print(prediction,states)
        total += torch.sum((prediction - states) ** 2).item()

print(total)