In [2]:
# ================================================================
#   PPO + Reward Model + Connect4
#   WITH MATRIX + GRAPHICAL COMBINED VIEW
#   WITH FULL-SCREEN VIDEO, PLAYER LABELS, WINNER HIGHLIGHTS
# ================================================================

# ---------- Cell 1: pip installs ----------
try:
    import subprocess, sys
    subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "pip"], check=True)
    subprocess.run([sys.executable, "-m", "pip", "install",
                    "gymnasium==0.28.1", "torch", "torchvision", "imageio",
                    "matplotlib", "tqdm", "moviepy", "pillow"], check=True)
except Exception as e:
    print("Install warning:", e)

# ================================================================
# Imports
# ================================================================
import os
import random
import numpy as np
import imageio
from PIL import Image, ImageDraw, ImageFont

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange
from IPython.display import Video, display

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Font
try:
    FONT = ImageFont.truetype("DejaVuSans-Bold.ttf", 18)
    FONT_LARGE = ImageFont.truetype("DejaVuSans-Bold.ttf", 24)
except:
    FONT = ImageFont.load_default()
    FONT_LARGE = ImageFont.load_default()

def text_size(font, text):
    bbox = font.getbbox(text)
    return bbox[2]-bbox[0], bbox[3]-bbox[1]

# ================================================================
# CONNECT4 ENVIRONMENT
# ================================================================
class Connect4Env:
    ROWS = 6
    COLS = 7

    def __init__(self, opponent="mixed", reward_shaping=True, render_cell_size=64):
        self.opponent = opponent
        self.reward_shaping = reward_shaping
        self.render_cell = render_cell_size

        self.board = np.zeros((self.ROWS, self.COLS), dtype=np.int8)
        self.done = False
        self.last_move = None
        self.winning_line = None

    def reset(self):
        self.board.fill(0)
        self.done = False
        self.last_move = None
        self.winning_line = None
        return self.board.copy()

    def legal_actions(self):
        return [c for c in range(self.COLS) if self.board[0,c]==0]

    def step(self, action):
        info = {}
        if self.done:
            return self.board.copy(), 0, True, info

        # illegal
        if action<0 or action>=self.COLS or self.board[0,action]!=0:
            self.done = True
            info["invalid"]=True
            return self.board.copy(), -0.3, True, info

        # agent move
        r = self._drop(action, 1)
        self.last_move=(r,action,1)

        w,line = self._check_winner()
        if w==1:
            self.done=True
            self.winning_line=line
            info["winner"]=1
            return self.board.copy(), 1.0, True, info

        if self._draw():
            self.done=True
            return self.board.copy(), 0, True, info

        # opponent move
        opp = self._opp_action()
        if opp is not None:
            r2 = self._drop(opp, -1)
            self.last_move=(r2,opp,-1)
            w2,l2 = self._check_winner()
            if w2==-1:
                self.done=True
                self.winning_line=l2
                info["winner"]=-1
                return self.board.copy(), -1.0, True, info

        if self._draw():
            self.done=True
            return self.board.copy(), 0, True, info

        return self.board.copy(), (0.01 if self.reward_shaping else 0), False, info

    def _drop(self, col, p):
        for r in range(self.ROWS-1,-1,-1):
            if self.board[r,col]==0:
                self.board[r,col]=p
                return r

    def _draw(self):
        return np.all(self.board[0]!=0)

    def _check_winner(self):
        B=self.board; R=self.ROWS; C=self.COLS

        # H
        for r in range(R):
            for c in range(C-3):
                s=B[r,c:c+4].sum()
                if s==4: return 1, [(r,c+i) for i in range(4)]
                if s==-4: return -1, [(r,c+i) for i in range(4)]

        # V
        for c in range(C):
            for r in range(R-3):
                s=B[r:r+4,c].sum()
                if s==4: return 1, [(r+i,c) for i in range(4)]
                if s==-4: return -1, [(r+i,c) for i in range(4)]

        # diag down
        for r in range(R-3):
            for c in range(C-3):
                vals=[B[r+i,c+i] for i in range(4)]
                if sum(vals)==4: return 1, [(r+i,c+i) for i in range(4)]
                if sum(vals)==-4: return -1, [(r+i,c+i) for i in range(4)]

        # diag up
        for r in range(3,R):
            for c in range(C-3):
                vals=[B[r-i,c+i] for i in range(4)]
                if sum(vals)==4: return 1, [(r-i,c+i) for i in range(4)]
                if sum(vals)==-4: return -1, [(r-i,c+i) for i in range(4)]

        return 0,None

    # mixed opponent
    def _opp_action(self):
        legal = self.legal_actions()
        if not legal: return None
        mode = random.choice(["random","heuristic"])
        if mode=="random":
            return random.choice(legal)
        return self._heuristic(legal)

    def _heuristic(self, legal):
        # win
        for c in legal:
            tmp=self.board.copy()
            self._sim(tmp,c,-1)
            if self._winner(tmp)==-1:
                return c

        # block
        for c in legal:
            tmp=self.board.copy()
            self._sim(tmp,c,1)
            if self._winner(tmp)==1:
                return c

        if 3 in legal: return 3
        return random.choice(legal)

    def _sim(self,B,c,p):
        for r in range(self.ROWS-1,-1,-1):
            if B[r,c]==0:
                B[r,c]=p
                return

    def _winner(self,B):
        R,C=self.ROWS,self.COLS
        for r in range(R):
            for c in range(C-3):
                s=B[r,c:c+4].sum()
                if s==4: return 1
                if s==-4: return -1
        return 0

    # ============ GRAPHICAL RENDER + MATRIX SIDE-BY-SIDE ============
    def render(self):
        board_img = self._render_graphical()
        matrix_img = self._render_matrix()

        # Combine horizontally
        w1,h1 = board_img.size
        w2,h2 = matrix_img.size
        H = max(h1,h2); W = w1 + w2

        combined = Image.new("RGB", (W,H), (255,255,255))
        combined.paste(board_img,(0,0))
        combined.paste(matrix_img,(w1,0))

        return np.array(combined)

    def _render_graphical(self):
        cell = self.render_cell
        H = self.ROWS*cell; W=self.COLS*cell
        canvas = np.ones((H,W,3),dtype=np.uint8)*255

        # draw
        for r in range(self.ROWS):
            for c in range(self.COLS):
                y0=r*cell; x0=c*cell
                canvas[y0:y0+cell, x0:x0+cell] = 230
                if self.board[r,c]!=0:
                    color = np.array([200,50,50]) if self.board[r,c]==1 else np.array([70,130,200])
                    cy, cx = y0+cell//2, x0+cell//2
                    rr,cc = np.ogrid[y0:y0+cell, x0:x0+cell]
                    mask=(rr-cy)**2+(cc-cx)**2 <= (cell//2-6)**2
                    region=canvas[y0:y0+cell, x0:x0+cell]
                    region[mask]=color
                    canvas[y0:y0+cell, x0:x0+cell]=region

        return Image.fromarray(canvas)

    def _render_matrix(self):
        """Render the board matrix as text image."""
        cell_h = 24
        text = "\n".join([" ".join(f"{x:2d}" for x in row) for row in self.board])
        lines = text.split("\n")

        w = 300
        h = cell_h * len(lines) + 10
        img = Image.new("RGB",(w,h),(250,250,250))
        d = ImageDraw.Draw(img)

        y=5
        for line in lines:
            d.text((5,y), line, fill=(0,0,0), font=FONT)
            y+=cell_h

        return img

# ===================================================================
# ACTOR, REWARD MODEL — SAME AS BEFORE
# ===================================================================
class ActorCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv=nn.Sequential(
            nn.Conv2d(1,32,3,padding=1), nn.ReLU(),
            nn.Conv2d(32,64,3,padding=1), nn.ReLU(),
            nn.Flatten()
        )
        conv_out=self.conv(torch.zeros(1,1,6,7)).shape[1]
        self.actor=nn.Sequential(nn.Linear(conv_out,128), nn.Tanh(), nn.Linear(128,7))
        self.critic=nn.Sequential(nn.Linear(conv_out,128), nn.Tanh(), nn.Linear(128,1))

    def forward(self,x):
        if x.dim()==3: x=x.unsqueeze(1)
        f=self.conv(x.float())
        return self.actor(f), self.critic(f).squeeze(-1)

class RewardModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv=nn.Sequential(nn.Conv2d(1,16,3,padding=1),nn.ReLU(),nn.Flatten())
        conv_out=self.conv(torch.zeros(1,1,6,7)).shape[1]
        self.head=nn.Sequential(nn.Linear(conv_out+7,64),nn.ReLU(),nn.Linear(64,1))

    def forward(self,obs,aoh):
        if obs.dim()==3: obs=obs.unsqueeze(1)
        f=self.conv(obs.float())
        return self.head(torch.cat([f,aoh],dim=-1)).squeeze(-1)

# ===================================================================
# BUFFER
# ===================================================================
from collections import namedtuple
Transition=namedtuple("Transition","obs action logp reward done value")

class RolloutBuffer:
    def __init__(self): self.data=[]
    def add(self,*x): self.data.append(Transition(*x))
    def clear(self): self.data=[]
    def gae(self,last_val,gamma,lam):
        if not self.data:
            return (torch.empty(0,6,7),)*5
        obs=np.array([t.obs for t in self.data])
        act=np.array([t.action for t in self.data])
        rew=np.array([t.reward for t in self.data])
        val=np.array([t.value for t in self.data])
        done=np.array([t.done  for t in self.data])

        adv=np.zeros_like(rew)
        gae=0
        for i in reversed(range(len(rew))):
            mask=1-done[i]
            nv=last_val if i==len(rew)-1 else val[i+1]
            delta=rew[i]+gamma*nv*mask - val[i]
            gae=delta+gamma*lam*mask*gae
            adv[i]=gae

        ret=adv+val

        return (
            torch.tensor(obs).float(),
            torch.tensor(act),
            torch.tensor([t.logp for t in self.data]).float(),
            torch.tensor(ret).float(),
            torch.tensor(adv).float()
        )

# ===================================================================
# PPO TRAINER
# ===================================================================
class PPOTrainer:
    def __init__(self, env, model, reward_model):
        self.env=env
        self.model=model.to(device)
        self.rm=reward_model.to(device)

        self.opt=optim.Adam(self.model.parameters(),lr=3e-4)
        self.rm_opt=optim.Adam(self.rm.parameters(),lr=1e-3)
        self.buf=RolloutBuffer()

    def act(self,obs):
        o=torch.tensor(obs).float().unsqueeze(0).unsqueeze(0).to(device)
        with torch.no_grad():
            logits,v=self.model(o)
            probs=torch.softmax(logits,-1)
            dist=torch.distributions.Categorical(probs)
            a=dist.sample()
        return int(a.item()), float(dist.log_prob(a)), float(v.item())

    def collect(self,n=2048):
        obs=self.env.reset().astype(np.float32)
        for _ in range(n):
            a,lp,v=self.act(obs)
            no,r,d,info=self.env.step(a)
            self.buf.add(obs,a,lp,r,d,v)
            obs=no.astype(np.float32)
            if d: obs=self.env.reset().astype(np.float32)

    def train_rm(self):
        obs=torch.tensor([t.obs for t in self.buf.data]).float().to(device)
        act=torch.tensor([t.action for t in self.buf.data]).long().to(device)
        rew=torch.tensor([t.reward for t in self.buf.data]).float().to(device)

        aoh=torch.zeros(len(act),7).to(device)
        aoh[torch.arange(len(act)),act]=1.0

        ds=torch.utils.data.TensorDataset(obs,aoh,rew)
        dl=torch.utils.data.DataLoader(ds,batch_size=64,shuffle=True)

        for _ in range(3):
            for o,a,r in dl:
                pred=self.rm(o,a)
                loss=((pred-r)**2).mean()
                self.rm_opt.zero_grad()
                loss.backward()
                self.rm_opt.step()

    def update(self,gamma=0.99,lam=0.95,clip=0.2):
        last=torch.tensor(self.env.board).float().unsqueeze(0).unsqueeze(0).to(device)
        with torch.no_grad():
            _,lv=self.model(last)
            lv=lv.item()

        obs,act,oldlogp,ret,adv=self.buf.gae(lv,gamma,lam)
        self.train_rm()

        obs=obs.to(device)
        act=act.to(device)
        oldlogp=oldlogp.to(device)
        ret=ret.to(device)
        adv=(adv-adv.mean())/(adv.std()+1e-8); adv=adv.to(device)

        idx=np.arange(len(obs))
        for _ in range(4):
            np.random.shuffle(idx)
            for i in range(0,len(obs),128):
                b=idx[i:i+128]
                o=obs[b].unsqueeze(1)
                a=act[b]
                ol=oldlogp[b]
                r=ret[b]
                ad=adv[b]

                logits,v=self.model(o)
                probs=torch.softmax(logits,-1)
                dist=torch.distributions.Categorical(probs)
                nl=dist.log_prob(a)

                ratio=torch.exp(nl-ol)
                s1=ratio*ad
                s2=torch.clamp(ratio,1-clip,1+clip)*ad

                a_loss=-torch.min(s1,s2).mean()
                c_loss=(v-r).pow(2).mean()
                ent=dist.entropy().mean()

                loss=a_loss+0.5*c_loss-0.01*ent
                self.opt.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(),0.5)
                self.opt.step()

        self.buf.clear()

    def train(self,total_steps=200000,rollout=2048):
        for _ in trange(0,total_steps,rollout):
            self.collect(rollout)
            self.update()
        return self.rm

# ===================================================================
# EVAL
# ===================================================================
def evaluate(model,env,episodes=20,sampling=True):
    w=l=d=0
    for _ in range(episodes):
        o=env.reset(); done=False; lr=0
        while not done:
            obs=torch.tensor(o).float().unsqueeze(0).unsqueeze(0).to(device)
            with torch.no_grad():
                logits,_=model(obs)
                probs=torch.softmax(logits,-1).cpu().numpy()[0]
            if sampling:
                a=np.random.choice(len(probs),p=probs)
            else:
                a=int(np.argmax(probs))
            o,r,done,info=env.step(a); lr=r
        if lr>0:w+=1
        elif lr<0:l+=1
        else:d+=1
    return {"wins":w,"losses":l,"draws":d}

# ===================================================================
# OVERLAY + RECORDING
# ===================================================================
def draw_overlay(frame, header, txt):
    img=Image.fromarray(frame)
    d=ImageDraw.Draw(img,"RGBA")
    W,H=img.size

    # Header bar
    rect_h=25*len(header)
    d.rectangle([(0,0),(W,rect_h)], fill=(20,20,20,200))
    y=2
    for line in header:
        d.text((10,y), line, fill="white", font=FONT)
        y+=22

    if txt:
        w,h=text_size(FONT_LARGE,txt)
        d.rectangle([(0,rect_h),(W,rect_h+40)],fill=(0,0,0,160))
        d.text(((W-w)/2, rect_h+5), txt, fill="white", font=FONT_LARGE)

    return np.array(img)

def record_play(model, env, episodes=10, out="connect4_demo.mp4", eps=0.25):
    frames=[]
    res={"wins":0,"losses":0,"draws":0,"illegal":0}

    for ep in range(1,episodes+1):
        o=env.reset(); d=False; step=0; info={}

        while not d:
            step+=1

            # Model action distribution
            obs=torch.tensor(o).float().unsqueeze(0).unsqueeze(0).to(device)
            with torch.no_grad():
                logits,_=model(obs)
                probs=torch.softmax(logits,-1).cpu().numpy()[0]

            # ε-greedy
            if random.random()<eps:
                a=random.choice(env.legal_actions())
                actor_txt=f"Player 1 (Agent) RANDOM → column {a}"
            else:
                a=np.random.choice(len(probs),p=probs)
                actor_txt=f"Player 1 (Agent) plays column {a}"

            # Graphical + matrix frame BEFORE move
            frame=env.render()
            frame=draw_overlay(frame,[f"Episode {ep}/{episodes}", f"Step {step}"], actor_txt)
            frames.append(frame)

            # Execute agent move → environment also plays opponent
            prev_board = env.board.copy()
            o,r,d,info=env.step(a)

            # Determine who played
            if env.last_move and env.last_move[2] == -1:
                opp_col = env.last_move[1]
                step_txt=f"Player -1 (Opponent) plays column {opp_col}"
            else:
                step_txt="Move executed"

            # For terminal result
            if d:
                if "invalid" in info:
                    step_txt="Illegal move → terminal"
                elif r>0:
                    step_txt="WINNER: Player 1 (Agent)"
                elif r<0:
                    step_txt="WINNER: Opponent"
                else:
                    step_txt="DRAW"

            # Frame AFTER move
            frame2=env.render()
            frame2=draw_overlay(frame2,[f"Episode {ep}/{episodes}", f"Step {step}"], step_txt)
            frames.append(frame2)

        # count results
        if r>0: res["wins"]+=1
        elif r<0:
            if "invalid" in info: res["illegal"]+=1
            res["losses"]+=1
        else: res["draws"]+=1

    imageio.mimsave(out, frames, fps=3)
    return out, res

# ===================================================================
# MAIN
# ===================================================================
if __name__ == "__main__":
    env=Connect4Env(opponent="mixed", render_cell_size=64)
    model=ActorCritic()
    rm=RewardModel()

    trainer=PPOTrainer(env, model, rm)

    print("==== Training PPO (~200k steps) ====")
    trained_rm=trainer.train(total_steps=200000)

    print("\n==== Reward Model Returned ====")
    print(trained_rm)

    torch.save(trained_rm.state_dict(), "reward_model_final.pth")
    print("Saved reward model → reward_model_final.pth")

    print("\n==== Evaluation ====")
    print(evaluate(model, env, sampling=True))

    print("\n==== Recording Gameplay Video ====")
    vid,res=record_play(model, env, episodes=12, eps=0.25)
    print("Play results:", res)
    print("Video saved:", vid)

    # Display full-screen-capable video
    display(Video(vid, embed=True))

Using device: cuda
==== Training PPO (~200k steps) ====


100%|██████████| 98/98 [08:10<00:00,  5.00s/it]



==== Reward Model Returned ====
RewardModel(
  (conv): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Flatten(start_dim=1, end_dim=-1)
  )
  (head): Sequential(
    (0): Linear(in_features=679, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=1, bias=True)
  )
)
Saved reward model → reward_model_final.pth

==== Evaluation ====
{'wins': 20, 'losses': 0, 'draws': 0}

==== Recording Gameplay Video ====




Play results: {'wins': 11, 'losses': 1, 'draws': 0, 'illegal': 0}
Video saved: connect4_demo.mp4
