In [2]:
import pygame, sys, os, time
import numpy as np
import torch, torch.nn as nn, torch.optim as optim

  from pkg_resources import resource_stream, resource_exists


pygame 2.6.1 (SDL 2.28.4, Python 3.13.9)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [None]:
# ---------- configuration ----------
# Window size in pixels.
SCREEN_W = 640
SCREEN_H = 480

# Circle radii (pixels) of the end effector, the block and the target zone.
EE_R, BLOCK_R, TARGET_R = 8, 14, 18

# Files used to persist recorded demonstrations and model weights.
DEMO_PATH = "demos.npy"
MODEL_PATH = "bc_imitation.pth"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

class PickPlace:
    """Minimal 2-D pick-and-place task driven by end-effector displacements.

    The end effector (``ee``) starts at the screen centre, the block is placed
    at a random position and the target sits near the top of the screen. The
    block is grabbed automatically once the end effector comes close enough,
    is then carried along, and the episode succeeds when the carried block
    reaches the target zone.

    All positions are kept as float32 arrays in pixel coordinates.
    """

    def __init__(self, seed=None):
        """Create the environment and put it into a fresh starting state.

        Args:
            seed: Optional seed for a private NumPy generator used to sample
                block positions. With the default ``None`` the global
                ``np.random`` state is used.
        """
        # Both the np.random module and a Generator expose ``uniform``.
        self._rng = np.random if seed is None else np.random.default_rng(seed)
        self.reset()

    def reset(self):
        """Start a new episode and return the initial state vector."""
        self.ee = np.array([SCREEN_W / 2, SCREEN_H / 2], dtype=np.float32)
        # x is sampled before y so the random stream is consumed in a fixed order.
        block_x = self._rng.uniform(80, SCREEN_W - 80)
        block_y = self._rng.uniform(80, SCREEN_H - 120)
        self.block = np.array([block_x, block_y], dtype=np.float32)
        self.has_block = False
        self.target = np.array([SCREEN_W / 2, 60], dtype=np.float32)
        return self.get_state()

    def get_state(self):
        """Return the 7-element float32 observation.

        Layout: end-effector x, y; block x, y; target x, y (each divided by
        the screen width or height) followed by the ``has_block`` flag as
        0.0 or 1.0.
        """
        st = [self.ee[0] / SCREEN_W, self.ee[1] / SCREEN_H,
              self.block[0] / SCREEN_W, self.block[1] / SCREEN_H,
              self.target[0] / SCREEN_W, self.target[1] / SCREEN_H,
              float(self.has_block)]
        return np.array(st, dtype=np.float32)

    def step(self, action):
        """Move the end effector by ``action`` and update the task state.

        Args:
            action: Two-element displacement in pixels (array or sequence).

        Returns:
            ``(state, done, success)`` where ``state`` is the new observation
            and ``done``/``success`` are both True once the carried block lies
            within ``TARGET_R`` of the target.
        """
        # Cast once so lists and float64 arrays are accepted by the in-place add.
        self.ee += np.asarray(action, dtype=np.float32)
        # float32 bounds keep the clipped position float32; integer list
        # bounds would promote the result to float64.
        lower = np.array([EE_R, EE_R], dtype=np.float32)
        upper = np.array([SCREEN_W - EE_R, SCREEN_H - EE_R], dtype=np.float32)
        self.ee = np.clip(self.ee, lower, upper)
        # pick if close
        if not self.has_block and np.linalg.norm(self.ee - self.block) < (EE_R + BLOCK_R + 4):
            self.has_block = True
        # carry block
        if self.has_block:
            self.block = self.ee.copy()
        done = False
        success = False
        if self.has_block and np.linalg.norm(self.block - self.target) < TARGET_R:
            done = True
            success = True
        return self.get_state(), done, success

# ---------- BC model ----------
class BCNet(nn.Module):
    """Behaviour-cloning policy: a fully connected network with two hidden layers.

    Args:
        in_dim: Size of the input vector (7 by default).
        out_dim: Size of the output vector (2 by default).
    """

    def __init__(self, in_dim=7, out_dim=2):
        super().__init__()
        hidden = 128
        layers = [
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        ]
        # The attribute name and layer order determine the state_dict keys.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for input ``x``."""
        return self.net(x)

# ---------- utilities ----------
def load_demos(path=DEMO_PATH):
    """Load recorded demonstrations from ``path``.

    Returns the array stored in the file, or an empty array of shape ``(0,)``
    when the file does not exist.
    """
    if not os.path.exists(path):
        return np.zeros((0,))
    # NOTE: allow_pickle=True unpickles the file's contents, which can execute
    # arbitrary code; only load demo files from a trusted source.
    return np.load(path, allow_pickle=True)

def save_demos(data, path=DEMO_PATH):
    """Write ``data`` to ``path`` using ``np.save``."""
    np.save(file=path, arr=data)




In [9]:
def _demos_to_arrays(demos):
    """Split recorded ``(state, action)`` pairs into two float32 matrices.

    The pairs are stacked component-wise instead of going through
    ``np.array(demos)``: each pair holds a 7-element state next to a
    2-element action, and NumPy >= 1.24 raises ``ValueError`` when asked to
    build an array from such ragged nesting without ``dtype=object``.
    """
    states = np.stack([np.asarray(s, dtype=np.float32) for s, _ in demos])
    actions = np.stack([np.asarray(a, dtype=np.float32) for _, a in demos])
    return states, actions


def _train_bc(model, opt, loss_fn, demos, epochs=60, batch=64):
    """Fit ``model`` to ``demos`` with shuffled minibatches, printing the mean loss every 10th epoch."""
    X, Y = _demos_to_arrays(demos)
    X_t = torch.tensor(X, device=DEVICE)
    Y_t = torch.tensor(Y, device=DEVICE)
    for ep in range(epochs):
        idx = np.random.permutation(len(X))
        losses = []
        for i in range(0, len(X), batch):
            bidx = idx[i:i + batch]
            pred = model(X_t[bidx])
            loss = loss_fn(pred, Y_t[bidx])
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        if ep % 10 == 0:
            print(f"Epoch {ep} loss {np.mean(losses):.4f}")


def _teleop_action(ee, mouse_pos, max_step=8.0):
    """Return a step from ``ee`` towards ``mouse_pos`` whose length is at most ``max_step``."""
    delta = np.array([mouse_pos[0] - ee[0], mouse_pos[1] - ee[1]], dtype=np.float32)
    dist = np.linalg.norm(delta)
    if dist > 1e-3:
        return (delta / dist) * min(max_step, dist)
    return np.array([0.0, 0.0], dtype=np.float32)


def _render(screen, font, env, recording, n_demos, run_autonomous, done):
    """Draw the target, block, end effector and status overlay, then flip the display."""
    screen.fill((25, 25, 30))
    # target zone
    pygame.draw.circle(screen, (60, 180, 70), (int(env.target[0]), int(env.target[1])), TARGET_R)
    # block (changes colour once it is being carried)
    col = (180, 80, 60) if not env.has_block else (200, 180, 60)
    pygame.draw.circle(screen, col, (int(env.block[0]), int(env.block[1])), BLOCK_R)
    # end effector
    pygame.draw.circle(screen, (200, 200, 200), (int(env.ee[0]), int(env.ee[1])), EE_R)
    # overlay
    txt = font.render(f"Record(d):{recording}  Demos:{n_demos}  Auton(p):{run_autonomous}", True, (240, 240, 240))
    screen.blit(txt, (8, 8))
    if done:
        txt2 = font.render("SUCCESS! Press r to reset.", True, (255, 220, 60))
        screen.blit(txt2, (8, 35))
    pygame.display.flip()


def main():
    """Run the interactive behaviour-cloning session until the window is closed.

    The end effector follows the mouse unless autonomous mode is on, in which
    case the model's output is used as the action. Keys:

    * ``d`` - toggle recording of (state, action) pairs
    * ``t`` - train the model on the recorded pairs (needs at least 20)
    * ``p`` - toggle autonomous mode
    * ``r`` - reset the environment
    * ``s`` / ``l`` - save / load the model weights at ``MODEL_PATH``

    Demonstrations are loaded from ``DEMO_PATH`` at start-up, written back
    every 200 recorded steps and once more when the window is closed.
    """
    pygame.init()
    screen = pygame.display.set_mode((SCREEN_W, SCREEN_H))
    pygame.display.set_caption("Imitation (BC) - press d to record, t train, p play, s save, l load")
    clock = pygame.time.Clock()
    # Build the font once; looking it up on every frame is needlessly expensive.
    font = pygame.font.SysFont("Arial", 16)

    env = PickPlace()
    env.reset()
    demos = []
    demos_arr = load_demos()
    if demos_arr.size:
        demos = list(demos_arr.tolist())
        print(f"Loaded {len(demos)} demo steps.")
    model = BCNet().to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.MSELoss()

    recording = False
    run_autonomous = False

    while True:
        for e in pygame.event.get():
            if e.type == pygame.QUIT:
                save_demos(np.array(demos, dtype=object), DEMO_PATH)
                pygame.quit()
                return
            elif e.type == pygame.KEYDOWN:
                if e.key == pygame.K_d:
                    recording = not recording
                    print("Recording:", recording)
                elif e.key == pygame.K_t:
                    # train BC quickly
                    if len(demos) < 20:
                        print("Need more demo data (>=20 steps). Current:", len(demos))
                    else:
                        print("Training BC on", len(demos), "samples")
                        _train_bc(model, opt, loss_fn, demos)
                        print("Training done.")
                elif e.key == pygame.K_p:
                    run_autonomous = not run_autonomous
                    print("Autonomous:", run_autonomous)
                elif e.key == pygame.K_r:
                    env.reset()
                    print("Reset env")
                elif e.key == pygame.K_s:
                    torch.save(model.state_dict(), MODEL_PATH)
                    print("Saved model")
                elif e.key == pygame.K_l:
                    if os.path.exists(MODEL_PATH):
                        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
                        print("Loaded model")
                    else:
                        print("No model file found")

        # teleop: small step towards the mouse cursor
        action = _teleop_action(env.ee, pygame.mouse.get_pos())

        # if autonomous: let model predict
        if run_autonomous:
            with torch.no_grad():
                s_t = torch.tensor(env.get_state(), dtype=torch.float32, device=DEVICE).unsqueeze(0)
                action = model(s_t).squeeze(0).cpu().numpy()

        # step env
        state_before = env.get_state()
        _, done, _ = env.step(action)

        # Record only teleoperated steps: while the model is driving, the
        # action is its own prediction and would pollute the demonstrations.
        if recording and not run_autonomous:
            demos.append((state_before, action.tolist()))
            if len(demos) % 200 == 0:
                save_demos(np.array(demos, dtype=object), DEMO_PATH)
                print("Saved demos:", len(demos))

        _render(screen, font, env, recording, len(demos), run_autonomous, done)
        clock.tick(60)

# Start the interactive session when this code is executed directly.
if __name__ == "__main__":
    main()

Loaded 2422 demo steps.
Recording: True
Saved demos: 2600
Reset env
Saved demos: 2800
Reset env
Saved demos: 3000
Reset env
Saved demos: 3200
Reset env
Saved demos: 3400
Reset env
Saved demos: 3600
Reset env
Saved demos: 3800
Autonomous: True
Recording: False
