

# Training a Classification Model for Imitation Learning

We will use this notebook to train a classification-based imitation learning model **on Google Colab**. Open this `.ipynb` file in Google Colab and click Connect in the top-right menu to attach a GPU runtime.

In [None]:
#@title ⬇️ Install core deps (PyTorch, TorchVision, OpenCV, tqdm, OmegaConf)
!pip -q install --upgrade pip
# Colab GPU runtimes usually work with the CUDA 12.1 wheels below:
!pip -q install "torch==2.4.0" "torchvision==0.19.0" "torchaudio==2.4.0" --index-url https://download.pytorch.org/whl/cu121
# omegaconf is needed by the configuration cell further down
!pip -q install opencv-python tqdm omegaconf
import torch, torchvision, cv2
print("Torch:", torch.__version__, "| TorchVision:", torchvision.__version__, "| OpenCV:", cv2.__version__)
print("CUDA available:", torch.cuda.is_available())


## Loading training data
Upload your training data as a zip file to your Google Drive. The following code mounts your Google Drive and unzips the archive into the Colab filesystem.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import zipfile, os, shutil

#TODO: Fill with path to training data zip
data_zip = "/content/drive/MyDrive/MAE345-2025-Sneha/imitation_data_TA_data.zip"

dset_dir = "/content/imitation_data"

# Unzip data
with zipfile.ZipFile(data_zip, "r") as zf:
    zf.extractall(dset_dir)

print("Unzipped into:", dset_dir)
!ls -R $dset_dir | head -n 30


## Visualizing the data

In [None]:
import cv2, matplotlib.pyplot as plt
from pathlib import Path
import json

# Take the first trial directory and the first logged entry in it
trial0 = sorted([d for d in Path(dset_dir).iterdir() if d.is_dir() and d.name.startswith("trial_")])[0]
with open(trial0 / "data_log.json") as f:
    e0 = json.load(f)[0]
img = cv2.cvtColor(cv2.imread(str(trial0 / e0["image_path"])), cv2.COLOR_BGR2RGB)
plt.figure(figsize=(4,4)); plt.imshow(img); plt.axis("off"); plt.title(trial0.name)
print("state:", e0["state"])
print("action:", e0["action"])


## Defining Dataloaders and Model Architecture

Here, you can import or copy over your model and dataloader. 
You should modify and optimize them for your final project to improve performance.

In [None]:
# Organize configurations

from omegaconf import OmegaConf

# OmegaConf.create takes a single dict (or list/YAML string), not keyword arguments
cfg = OmegaConf.create({
    'dataset_cfg': {
        'data_dir': '/content/imitation_data', # path to dataset
        'train_trials': None, # Use auto split if None
        'val_trials': None, # Use auto split if None
        'batch_size': 32, # batch size
        'image_size': [224, 224], # height, width of image resized to
        'normalize_states': False, # whether to normalize states to zero mean and unit variance
        'normalize_actions': False, # whether to normalize actions to zero mean and unit variance
        'num_workers': 4, # number of dataloader workers
        'shuffle_train': True, # whether to shuffle training data
    },
    'action_space': 'discrete', # or 'continuous'
    'model_cfg': {
        'pretrained': False, # whether to use a pretrained backbone
        'action_dim': 4, # dimension of action space (vx, vy, vz, yaw_rate)
        'num_bins': 11, # number of bins to discretize action space into
        'action_low': -1.0, # lower bound of continuous action range
        'action_high': 1.0, # upper bound of continuous action range
    }, # TODO: verify the action ranges in your dataset!
    'training_cfg': {
        'num_epochs': 50, # number of training epochs
        'lr': 0.001, # learning rate
    },
})
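
The `num_bins`, `action_low`, and `action_high` settings above describe how continuous actions become class labels. As a rough illustration, the sketch below maps each action value to the nearest of `num_bins` evenly spaced bin centers. The helper name `actions_to_bins` is hypothetical, and the nearest-center convention is an assumption; the course's own dataloader may bin differently.

```python
import numpy as np

def actions_to_bins(actions, num_bins=11, low=-1.0, high=1.0):
    """Map continuous actions to integer bin indices in [0, num_bins - 1].

    Assumes bin centers are evenly spaced from `low` to `high` (inclusive).
    Values outside the range are clipped to the nearest edge bin.
    """
    actions = np.clip(np.asarray(actions, dtype=np.float64), low, high)
    # Rescale to [0, num_bins - 1], then round to the nearest bin center.
    scaled = (actions - low) / (high - low) * (num_bins - 1)
    return np.rint(scaled).astype(np.int64)

# One action row: (vx, vy, vz, yaw_rate)
example = actions_to_bins([[-1.0, 0.0, 0.4, 1.0]])
print(example)
```

With the default settings, the example row maps to bins 0, 5, 7, and 10. If your logged actions fall outside `[action_low, action_high]`, they are silently clipped here, which is one reason to check the action ranges in your dataset first.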

Helper functions to load training data collected with keyboard control



In [None]:
from torch.utils.data import Dataset, DataLoader

# TODO: define your dataset class here
class CrazyFlieILDataset(Dataset):
    # example dataset available in drone/datasets/dataloader.py
    pass

# TODO: create train and validation dataloaders
def create_dataloaders(cfg: dict) -> tuple[DataLoader, DataLoader]:
    # example function available in drone/datasets/dataloader.py
    pass
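
When `train_trials` and `val_trials` are left as `None`, the dataloader needs some automatic split. One possible approach, sketched below, is to split by trial directory rather than by frame, so that frames from the same flight do not end up in both sets. The function name `split_trials`, the fixed seed, and the by-trial strategy are assumptions for illustration, not the behavior of `drone/datasets/dataloader.py`.

```python
import random

def split_trials(trial_names, val_fraction=0.2, seed=0):
    """Shuffle trial names reproducibly and hold out a fraction for validation.

    Returns (train_trials, val_trials) as sorted lists.
    """
    names = sorted(trial_names)
    random.Random(seed).shuffle(names)
    # Keep at least one validation trial whenever there is more than one trial.
    num_val = max(1, round(len(names) * val_fraction)) if len(names) > 1 else 0
    return sorted(names[num_val:]), sorted(names[:num_val])

train_trials, val_trials = split_trials([f"trial_{i}" for i in range(10)])
print(len(train_trials), "train trials,", len(val_trials), "validation trials")
```

With ten trials and the default fraction, this holds out two trials for validation and keeps eight for training.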


Define the neural network

In [None]:
import torch.nn as nn

# TODO: define your model architecture here
class DroneControlNet(nn.Module):
    # example models available in drone/models/discrete_action_model.py
    #                         and drone/models/continuous_action_model.py
    pass
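
For the discrete action space, the network needs one set of `num_bins` logits per action dimension. The sketch below is a deliberately small image-only CNN that produces logits of shape `(batch, action_dim, num_bins)`. The class name `TinyBinClassifier` and its layer sizes are hypothetical; it ignores the drone state and is not the model in `drone/models/discrete_action_model.py`.

```python
import torch
import torch.nn as nn

class TinyBinClassifier(nn.Module):
    """Small CNN that outputs per-dimension bin logits (illustrative only)."""

    def __init__(self, action_dim=4, num_bins=11):
        super().__init__()
        self.action_dim, self.num_bins = action_dim, num_bins
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=2, padding=2), nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # global average pool -> (B, 64, 1, 1)
            nn.Flatten(),             # -> (B, 64)
        )
        self.head = nn.Linear(64, action_dim * num_bins)

    def forward(self, images):
        logits = self.head(self.features(images))
        # One row of num_bins logits for each action dimension.
        return logits.view(-1, self.action_dim, self.num_bins)

net = TinyBinClassifier()
logits = net(torch.zeros(2, 3, 224, 224))
print(logits.shape)
```

Because of the global average pooling, the same network accepts other image sizes without changes to the linear head.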

## Training the model
Here, we load the dataset (which defaults to an 80/20 train–validation split), convert the continuous actions into discrete bins, and train the imitation learning model.

Make sure you are connected to a GPU runtime. Training may take some time.
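
Once actions are binned, each action dimension becomes its own classification problem. A common choice of loss, sketched below under the assumption that the model outputs logits of shape `(batch, action_dim, num_bins)` and the labels are integer bin indices of shape `(batch, action_dim)`, is cross-entropy averaged over all dimensions. The name `binned_cross_entropy` is hypothetical, and your own tensor shapes may differ.

```python
import math
import torch
import torch.nn.functional as F

def binned_cross_entropy(logits, labels):
    """Cross-entropy averaged over the batch and over action dimensions.

    logits: float tensor of shape (B, action_dim, num_bins)
    labels: int64 tensor of shape (B, action_dim) holding bin indices
    """
    batch, action_dim, num_bins = logits.shape
    # Treat every (sample, action dimension) pair as one classification example.
    return F.cross_entropy(
        logits.reshape(batch * action_dim, num_bins),
        labels.reshape(batch * action_dim),
    )

labels = torch.randint(0, 11, (8, 4))
uniform_loss = binned_cross_entropy(torch.zeros(8, 4, 11), labels)
print(uniform_loss.item(), "vs log(11) =", math.log(11))
```

All-zero logits give a uniform prediction over the 11 bins, so the loss equals log(11) regardless of the labels. An untrained model should start near that value, which makes it a handy sanity check.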

In [None]:
import torch

# After filling in the cells above, you can run your training:

# set up dataset
# TODO: parse the input cfg according to your function definitions
train_loader, val_loader = create_dataloaders(cfg.dataset_cfg)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# initialize model, loss function, optimizer
# TODO: parse the input cfg according to your model definitions
num_bins = cfg.model_cfg.num_bins
model = DroneControlNet(cfg.model_cfg).to(device)

# TODO: Loss function
def loss_fn(outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """
    Loss function to train the model with.
    
    :param outputs: model outputs
    :param labels: ground-truth labels
    
    :return: computed loss
    """
    pass

optimizer = torch.optim.Adam(model.parameters(), lr=cfg.training_cfg.lr)


# training loop
for epoch in range(cfg.training_cfg.num_epochs):
    model.train()
    
    # TODO: implement training loop
    
    # TODO: implement validation loop
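
The training and validation passes share most of their structure, so one way to fill in the TODOs is a single helper that trains when it receives an optimizer and only evaluates otherwise. The sketch below assumes each batch is an `(inputs, labels)` pair; the helper name `run_epoch` is hypothetical, and the toy linear model and random data stand in for the real drone model and dataloaders.

```python
import torch
import torch.nn as nn

def run_epoch(model, loader, loss_fn, device, optimizer=None):
    """Run one pass over `loader` and return the mean per-sample loss.

    Trains the model if `optimizer` is given; otherwise evaluates without gradients.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, total_items = 0.0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            loss = loss_fn(model(inputs), labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * inputs.size(0)
            total_items += inputs.size(0)
    return total_loss / max(total_items, 1)

# Toy check on random data (stand-ins for the real model and loaders)
torch.manual_seed(0)
toy_device = torch.device("cpu")
toy_model = nn.Linear(8, 3).to(toy_device)
toy_batches = [(torch.randn(16, 8), torch.randint(0, 3, (16,))) for _ in range(4)]
toy_opt = torch.optim.Adam(toy_model.parameters(), lr=1e-2)
ce = nn.functional.cross_entropy

before = run_epoch(toy_model, toy_batches, ce, toy_device)
for _ in range(20):
    run_epoch(toy_model, toy_batches, ce, toy_device, toy_opt)
after = run_epoch(toy_model, toy_batches, ce, toy_device)
print(f"loss before: {before:.3f}, after: {after:.3f}")
```

Inside the epoch loop you would call the helper once with the optimizer on the training loader and once without it on the validation loader, then log both losses.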




## Testing Model Inference


Here, we pick random samples from the validation set and print both the model’s predicted actions and the ground-truth actions. The sample images are saved to the `inference_samples` folder, which you can view in the file-browser panel on the left.
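
To compare predictions against ground-truth actions, the predicted bins have to be converted back to continuous values. The sketch below takes the most likely bin for each action dimension and returns its center, assuming bin centers are evenly spaced between `action_low` and `action_high`. The helper name `logits_to_actions` is hypothetical and it works on NumPy arrays, so a PyTorch output would first need `.cpu().numpy()`.

```python
import numpy as np

def logits_to_actions(logits, low=-1.0, high=1.0):
    """Convert bin logits of shape (batch, action_dim, num_bins) to continuous actions.

    Picks the highest-scoring bin per action dimension and returns its center,
    assuming evenly spaced centers from `low` to `high`.
    """
    logits = np.asarray(logits)
    centers = np.linspace(low, high, logits.shape[-1])
    return centers[logits.argmax(axis=-1)]

# Fake logits whose best bins are 0, 5, 7, and 10 for the four action dimensions
demo_logits = np.zeros((1, 4, 11))
for dim, best_bin in enumerate([0, 5, 7, 10]):
    demo_logits[0, dim, best_bin] = 1.0

decoded = logits_to_actions(demo_logits)
print(decoded)
```

Here the decoded actions are approximately -1.0, 0.0, 0.4, and 1.0. With 11 bins over a range of width 2, the resolution is 0.2 per bin, so predicted and ground-truth actions can differ by up to half a bin even when the predicted class is correct.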

In [None]:
import cv2
import os

save_dir = "inference_samples"
os.makedirs(save_dir, exist_ok=True)

model.eval()
count = 0
num_samples = 5
with torch.no_grad():
    # TODO: implement testing/inference
    pass


## Saving the trained model
You will need to **download** this model to your local Crazyflie code directory in order to test it on the drone. Click the folder icon on the left to open the file browser, then download the saved model file.

In [None]:
SAVE_PATH = "/content/drone_control_model.pth" # TODO: change to preferred path
torch.save(model.state_dict(), SAVE_PATH)
print(f"✅ Model saved to {SAVE_PATH}")
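
Since only the `state_dict` is saved, the local code has to rebuild the same architecture before loading the weights. A minimal sketch of that round trip follows. The helper name `load_weights` is hypothetical, and the small `nn.Linear` is a stand-in for your own model class, which must be constructed with the same configuration used during training.

```python
import os
import tempfile
import torch
import torch.nn as nn

def load_weights(model, path):
    """Load a saved state_dict into an already-constructed model and set eval mode."""
    # map_location="cpu" lets a checkpoint saved on a Colab GPU load on a machine without one.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

# Round trip with a stand-in model
source_model = nn.Linear(4, 2)
tmp_path = os.path.join(tempfile.mkdtemp(), "toy_model.pth")
torch.save(source_model.state_dict(), tmp_path)

restored = load_weights(nn.Linear(4, 2), tmp_path)
print("weights match:", torch.equal(source_model.weight, restored.weight))
```

If the architecture on the local machine differs from the one trained here, `load_state_dict` raises an error about missing or unexpected keys, so keep the model definition and its configuration in sync between Colab and the drone code.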