In [1]:
# starcapnet.py — Multimodal Image Captioning (Image → Text)
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np, random

# --- Config ---
IMG_SIZE = 64
VOCAB = ["<PAD>", "<SOS>", "<EOS>", "This", "image", "has", "stars", "0","1","2","3","4","5","6","7","8","9"]
WORD2IDX = {w:i for i,w in enumerate(VOCAB)}
IDX2WORD = {i:w for w,i in WORD2IDX.items()}
VOCAB_SIZE = len(VOCAB)
EMB_DIM, HIDDEN_DIM = 32, 64
BATCH_SIZE, EPOCHS, LR = 32, 10, 1e-3
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Synthetic star images ---
def make_star_image(H, W, k, min_r=2, max_r=4, margin=5):
    img = np.zeros((H, W), dtype=np.float32)
    y = np.arange(H).reshape(-1,1); x = np.arange(W).reshape(1,-1)
    for _ in range(k):
        r = np.random.randint(min_r, max_r+1)
        cy = np.random.randint(margin+r, H-margin-r)
        cx = np.random.randint(margin+r, W-margin-r)
        mask = (y-cy)**2+(x-cx)**2 <= r**2
        img[mask] = 1.0
    return img

def generate_dataset(n=500):
    X, captions = [], []
    for _ in range(n):
        count = np.random.randint(0,10)
        img = make_star_image(IMG_SIZE, IMG_SIZE, count)
        caption = ["<SOS>", "This", "image", "has", str(count), "stars", "<EOS>"]
        X.append(img); captions.append(caption)
    return np.array(X), captions

def caption_to_ids(caption, max_len=7):
    ids = [WORD2IDX[w] for w in caption]
    if len(ids)<max_len: ids += [WORD2IDX["<PAD>"]] * (max_len-len(ids))
    return ids

# --- Dataset ---
class StarCaptionDataset(Dataset):
    def __init__(self, n=500):
        X, cap = generate_dataset(n)
        self.X = torch.tensor(X).unsqueeze(1).float()
        self.Y = torch.tensor([caption_to_ids(c) for c in cap])
    def __len__(self): return len(self.X)
    def __getitem__(self, idx): return self.X[idx], self.Y[idx]

# --- Model ---
class StarCapNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Image encoder (CNN)
        self.cnn = nn.Sequential(
            nn.Conv2d(1,16,3,padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(16,32,3,padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Flatten(),
        )
        self.enc2feat = nn.Linear(32*(IMG_SIZE//4)*(IMG_SIZE//4), HIDDEN_DIM)

        # Text decoder (Embedding + GRU)
        self.embed = nn.Embedding(VOCAB_SIZE, EMB_DIM)
        self.gru = nn.GRU(EMB_DIM, HIDDEN_DIM, batch_first=True)
        self.fc = nn.Linear(HIDDEN_DIM, VOCAB_SIZE)

    def forward(self, images, captions):
        # Encode image
        feat = self.cnn(images)
        h0 = self.enc2feat(feat).unsqueeze(0)  # initial hidden state

        # Embed text & run GRU
        emb = self.embed(captions[:,:-1])  # input all but last token
        out,_ = self.gru(emb, h0)
        logits = self.fc(out)  # predict next tokens
        return logits

    def generate(self, image, max_len=7):
        self.eval()
        with torch.no_grad():
            feat = self.cnn(image.unsqueeze(0))
            h = self.enc2feat(feat).unsqueeze(0)
            word = torch.tensor([[WORD2IDX["<SOS>"]]], device=image.device)
            caption = []
            for _ in range(max_len):
                emb = self.embed(word)
                out,h = self.gru(emb,h)
                logits = self.fc(out[:,-1])
                word = logits.argmax(1).unsqueeze(0)
                w = IDX2WORD[word.item()]
                if w=="<EOS>": break
                caption.append(w)
            return " ".join(caption)

# --- Train ---
def train():
    ds = StarCaptionDataset(1000)
    loader = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True)
    model = StarCapNet().to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR)
    loss_fn = nn.CrossEntropyLoss(ignore_index=WORD2IDX["<PAD>"])

    for epoch in range(EPOCHS):
        model.train(); total_loss=0
        for xb,yb in loader:
            xb,yb = xb.to(DEVICE), yb.to(DEVICE)
            logits = model(xb,yb)
            loss = loss_fn(logits.reshape(-1,VOCAB_SIZE), yb[:,1:].reshape(-1))
            opt.zero_grad(); loss.backward(); opt.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{EPOCHS}, Loss {total_loss/len(loader):.3f}")

    # Test on a few images
    test_ds = StarCaptionDataset(5)
    for i in range(5):
        img, cap = test_ds[i]
        pred = model.generate(img.to(DEVICE))
        true = " ".join([IDX2WORD[id.item()] for id in cap if id.item()!=0])
        print(f"\nTrue: {true}\nPred: {pred}")

if __name__=="__main__":
    train()


Epoch 1/10, Loss 1.363
Epoch 2/10, Loss 0.462
Epoch 3/10, Loss 0.309
Epoch 4/10, Loss 0.253
Epoch 5/10, Loss 0.225
Epoch 6/10, Loss 0.205
Epoch 7/10, Loss 0.184
Epoch 8/10, Loss 0.162
Epoch 9/10, Loss 0.173
Epoch 10/10, Loss 0.163

True: <SOS> This image has 1 stars <EOS>
Pred: This image has 1 stars

True: <SOS> This image has 9 stars <EOS>
Pred: This image has 9 stars

True: <SOS> This image has 6 stars <EOS>
Pred: This image has 7 stars

True: <SOS> This image has 1 stars <EOS>
Pred: This image has 0 stars

True: <SOS> This image has 6 stars <EOS>
Pred: This image has 7 stars
