In [3]:
import os, io, json, random, pathlib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms, datasets
from PIL import Image
from tqdm import tqdm
from fastapi import FastAPI, UploadFile, File
import uvicorn
# Use the GPU when one is visible; otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
# Seed both RNG sources used in this notebook so runs are repeatable.
random.seed(42)
torch.manual_seed(42)

## CNN Feature Extraction

In [4]:
# Input images and the on-disk cache for their CNN features.
image_dir = "dataset/archive-2/Images"
feature_store = "artifacts/resnet50_features.pt"
os.makedirs(os.path.dirname(feature_store), exist_ok=True)

# ImageNet-pretrained ResNet-50 used as a frozen feature extractor: the
# classification head is replaced by an identity, so the forward pass returns
# the vector that previously fed `fc` (cnn_fc_dim values per image).
cnn = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
cnn_fc_dim = cnn.fc.in_features
cnn.fc = nn.Identity()
cnn = cnn.to(device).eval()

# Deterministic preprocessing: fixed 224x224 resize, then ImageNet mean/std normalisation.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]
feat_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
    ]
)
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def extract_features(dir_path, out_path, extensions=IMAGE_EXTENSIONS):
    """Run every image under ``dir_path`` through the frozen ``cnn`` and save the features.

    Args:
        dir_path: Root directory, searched recursively.
        out_path: Destination of a ``torch.save`` payload
            ``{"features": {file basename: 1-D feature tensor}, "dim": cnn_fc_dim}``.
        extensions: File suffixes treated as images (compared case-insensitively).

    Returns:
        The ``{basename: feature tensor}`` dict that was written to ``out_path``.
        Files with the same basename in different sub-folders overwrite each other.
    """
    # Sorted so the processing order (and progress output) is reproducible.
    paths = sorted(
        p for p in pathlib.Path(dir_path).rglob("*")
        if p.is_file() and p.suffix.lower() in extensions
    )
    feats = {}
    with torch.no_grad():
        for p in tqdm(paths):
            # Context manager releases the file handle deterministically.
            with Image.open(p) as im:
                img = im.convert("RGB")
            t = feat_transform(img).unsqueeze(0).to(device)
            feats[p.name] = cnn(t).squeeze(0).cpu()
    torch.save({"features": feats, "dim": cnn_fc_dim}, out_path)
    return feats


# Extraction takes minutes; reuse the cached file on re-runs.
# Delete `feature_store` to force re-extraction after the image set changes.
if not os.path.exists(feature_store):
    extract_features(image_dir, feature_store)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /Users/muhammadusman/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:14<00:00, 7.04MB/s]
100%|██████████| 8091/8091 [03:21<00:00, 40.12it/s]


## Flickr8k Captioning

In [None]:
# Resolve dataset locations relative to the notebook's working directory.
notebook_dir = os.getcwd()
dataset_root = os.path.join(notebook_dir, "dataset", "archive-2")
flickr_root = dataset_root
flickr_images = os.path.join(flickr_root, "Images")
flickr_captions = os.path.join(flickr_root, "captions.txt")

# Debug: print the resolved image folder so path problems are visible immediately.
_images_found = os.path.isdir(flickr_images)
print(f"Looking for images at: {flickr_images}")
print(f"Directory exists: {_images_found}")

# Fail fast with a clear message rather than deep inside training.
if not _images_found:
    raise FileNotFoundError(f"Image folder not found at {flickr_images}")
if not os.path.isfile(flickr_captions):
    raise FileNotFoundError(f"Captions file not found at {flickr_captions}")

# Vocabulary / sequence hyper-parameters.
min_freq = 2  # words seen fewer times than this are encoded as <unk>
max_len = 20  # encoded caption length, counting <start>, <end> and padding
# Special tokens; Vocabulary assigns them ids 0..3 in this order.
pad_token, start_token, end_token, unk_token = "<pad>", "<start>", "<end>", "<unk>"
class Vocabulary:
    """Word <-> id mapping built from whitespace-tokenised captions.

    Ids 0-3 are reserved for ``<pad>``, ``<start>``, ``<end>`` and ``<unk>`` (in
    that order). Every other word gets the next free id once its count in
    ``build`` reaches ``min_freq``.
    """

    def __init__(self, min_freq=1):
        self.min_freq = min_freq
        self.freqs = {}  # word -> occurrence count, accumulated over build() calls
        self.stoi = {pad_token: 0, start_token: 1, end_token: 2, unk_token: 3}
        self.itos = [pad_token, start_token, end_token, unk_token]

    def build(self, lines):
        """Count the words in ``lines`` and add those reaching ``min_freq``.

        Counts accumulate across calls, so building twice on the same data
        doubles every count.
        """
        for line in lines:
            for w in line.split():
                self.freqs[w] = self.freqs.get(w, 0) + 1
        for w, f in self.freqs.items():
            if f >= self.min_freq and w not in self.stoi:
                self.stoi[w] = len(self.itos)
                self.itos.append(w)

    def encode(self, text, length=None):
        """Encode ``text`` as a fixed-length ``torch.long`` tensor of ids.

        Layout is ``<start> w1 ... wk <end> <pad> ...``. At most ``length - 2``
        words are kept, and words missing from the vocabulary become ``<unk>``.

        Args:
            text: Caption string, split on whitespace.
            length: Output length; defaults to the module-level ``max_len``.

        Raises:
            ValueError: if ``length`` cannot hold ``<start>`` and ``<end>``.
        """
        if length is None:
            length = max_len
        if length < 2:
            raise ValueError(f"length must be >= 2 to fit <start> and <end>, got {length}")
        unk_id = self.stoi[unk_token]
        word_ids = [self.stoi.get(w, unk_id) for w in text.split()][: length - 2]
        ids = [self.stoi[start_token], *word_ids, self.stoi[end_token]]
        ids += [self.stoi[pad_token]] * (length - len(ids))
        return torch.tensor(ids, dtype=torch.long)

    def decode(self, ids):
        """Join the words for ``ids``, stopping at ``<end>`` and skipping ``<start>``/``<pad>``."""
        words = []
        for i in ids:
            w = self.itos[i]
            if w == end_token:
                break
            if w not in {start_token, pad_token}:
                words.append(w)
        return " ".join(words)
class Flickr8kDataset(Dataset):
    """Flickr8k ``(image, caption)`` pairs read from a ``captions.txt`` CSV.

    Each non-empty line is ``<image filename>,<caption>``. Captions may contain
    commas, so only the first comma separates the fields. A leading
    ``image,caption`` header line is skipped when present. Constructing the
    dataset also feeds every caption to ``vocab.build``.

    Args:
        img_dir: Directory holding the image files named in the CSV.
        caption_file: Path to the captions CSV.
        vocab: Object providing ``build(lines)`` and ``encode(text)``.
        transform: Callable applied to each RGB PIL image.
    """

    def __init__(self, img_dir, caption_file, vocab, transform):
        # utf-8-sig: explicit encoding (not the platform default), and a BOM,
        # if present, is removed so header detection still works.
        with open(caption_file, encoding="utf-8-sig") as f:
            lines = [line.strip() for line in f if line.strip()]
        # Skip the first line only when it is the header, so header-less files
        # keep their first caption.
        if lines and lines[0].split(",", 1)[0].strip().lower() == "image":
            lines = lines[1:]
        raw = [line.split(",", 1) for line in lines]
        # Lines without a comma are malformed and are dropped.
        self.data = [(parts[0], parts[1]) for parts in raw if len(parts) == 2]
        vocab.build([c for _, c in self.data])
        self.vocab = vocab
        self.dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(transformed image, encoded caption ids)`` for row ``idx``."""
        fname, cap = self.data[idx]
        path = os.path.join(self.dir, fname)
        # Context manager releases the file handle deterministically.
        with Image.open(path) as im:
            img = im.convert("RGB")
        return self.transform(img), self.vocab.encode(cap)
# Deterministic ImageNet preprocessing (no augmentation), same steps as feat_transform.
caption_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Constructing the dataset also builds the vocabulary from every caption.
vocab = Vocabulary(min_freq=min_freq)
train_ds = Flickr8kDataset(
    img_dir=flickr_images,
    caption_file=flickr_captions,
    vocab=vocab,
    transform=caption_transform,
)
def collate(batch):
    """Stack a list of ``(image, caption_ids)`` samples into ``(images, captions)`` tensors."""
    first, second = [], []
    for sample in batch:
        first.append(sample[0])
        second.append(sample[1])
    return torch.stack(first), torch.stack(second)
train_loader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    collate_fn=collate,
    num_workers=0,  # load batches in the main process
)
# Model sizes; the vocabulary is complete once train_ds has been constructed.
vocab_size = len(vocab.itos)
embed_dim, hidden_dim = 256, 512
class CaptionModel(nn.Module):
    """LSTM captioner conditioned on a projected image feature.

    The same sequence layout is used for teacher-forced training and decoding::

        step:    0         1         2     ...
        input:   img       <start>   w1    ...
        target:  (unused)  w1        w2    ...

    The output at step 0 is not supervised (the training loop slices it off),
    so ``generate`` must also feed ``<start>`` before reading the first word.
    """

    def __init__(self, embed_dim, hidden_dim, vocab_size, feature_dim):
        super().__init__()
        # Id 0 is <pad>; padding_idx keeps its embedding fixed at zero.
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        # Maps the CNN feature into embedding space so it can be the first LSTM input.
        self.init_fc = nn.Linear(feature_dim, embed_dim)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, features, captions):
        """Teacher-forced logits.

        Args:
            features: ``(B, feature_dim)`` image features.
            captions: ``(B, T)`` token ids, beginning with ``<start>``.

        Returns:
            ``(B, T, vocab_size)`` logits; position ``t >= 1`` predicts ``captions[:, t]``.
        """
        emb = self.embed(captions[:, :-1])          # (B, T-1, E): last token is never an input
        init = self.init_fc(features).unsqueeze(1)  # (B, 1, E)
        out, _ = self.lstm(torch.cat([init, emb], dim=1))
        return self.fc(out)

    def generate(self, feature, max_len=20, start_id=1, end_id=2):
        """Greedy-decode a caption for a single image.

        Args:
            feature: 1-D ``(feature_dim,)`` image feature.
            max_len: Maximum number of tokens produced after ``<start>``.
            start_id: Id of ``<start>`` (1 in ``Vocabulary``).
            end_id: Id of ``<end>`` (2 in ``Vocabulary``); decoding stops after emitting it.

        Returns:
            ``[start_id, w1, w2, ...]``, ending with ``end_id`` if it was produced.
            Autograd is not disabled here; call under ``torch.no_grad()``.
        """
        device = feature.device
        # Step 0: prime the LSTM with the image and discard its output, matching
        # training, where that position is never supervised.
        img_in = self.init_fc(feature).unsqueeze(0).unsqueeze(1)  # (1, 1, E)
        _, hidden = self.lstm(img_in)
        seq = [start_id]
        token = start_id
        for _ in range(max_len):
            x = self.embed(torch.tensor([[token]], device=device))
            out, hidden = self.lstm(x, hidden)
            token = self.fc(out[:, -1]).argmax(dim=-1).item()
            seq.append(token)
            if token == end_id:
                break
        return seq
caption_model = CaptionModel(embed_dim, hidden_dim, vocab_size, cnn_fc_dim).to(device)
opt = optim.Adam(caption_model.parameters(), lr=1e-4)
# Target id 0 (<pad>) is excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=0)
caption_epochs = 5
caption_ckpt_dir = "artifacts"
# Create the folder here instead of relying on the feature-extraction cell having run.
os.makedirs(caption_ckpt_dir, exist_ok=True)
for epoch in range(caption_epochs):
    caption_model.train()
    total_loss = 0.0
    for imgs, caps in tqdm(train_loader, desc=f"caption epoch {epoch + 1}"):
        imgs, caps = imgs.to(device), caps.to(device)
        # The CNN is frozen; no gradients flow through it.
        with torch.no_grad():
            feats = cnn(imgs)
        logits = caption_model(feats, caps)  # (B, T, V)
        # Output t >= 1 predicts caps[:, t]; output 0 (image step) is unsupervised.
        # Both sides flatten to B*(T-1) rows, so the batch sizes match.
        step_logits = logits[:, 1:, :].reshape(-1, vocab_size)
        step_targets = caps[:, 1:].reshape(-1)
        loss = criterion(step_logits, step_targets)
        opt.zero_grad()
        loss.backward()
        opt.step()
        total_loss += loss.item()
    torch.save(
        {"model": caption_model.state_dict(), "vocab": vocab.itos},
        os.path.join(caption_ckpt_dir, f"caption_epoch_{epoch + 1}.pt"),
    )
    print(f"epoch {epoch+1} loss {total_loss / max(len(train_loader), 1):.4f}")

Looking for images at: /Users/muhammadusman/Desktop/dl/dataset/archive-2/Images
Directory exists: True


  0%|          | 0/1265 [00:03<?, ?it/s]


ValueError: Expected input batch_size (640) to match target batch_size (608).

In [None]:
caption_model.eval()
# Pick the sample deterministically. os.listdir order is arbitrary, and the folder
# may contain non-image files (e.g. .DS_Store on macOS).
_sample_names = sorted(
    n for n in os.listdir(flickr_images)
    if n.lower().endswith((".jpg", ".jpeg", ".png"))
)
if not _sample_names:
    raise FileNotFoundError(f"No images found in {flickr_images}")
sample_path = os.path.join(flickr_images, _sample_names[0])
with Image.open(sample_path) as im:
    img = im.convert('RGB')
t = caption_transform(img).unsqueeze(0).to(device)
with torch.no_grad():
    f = cnn(t)
    ids = caption_model.generate(f.squeeze(0))
# ids[0] is <start>; decode the generated words only.
text = vocab.decode(ids[1:])
print(text)

## Stanford Actions Classification

In [None]:
stanford_root = "data/stanford_actions"
train_dir = os.path.join(stanford_root, "train")
val_dir = os.path.join(stanford_root, "val")
# Training preprocessing, with random horizontal flips as augmentation.
action_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Evaluation preprocessing has no randomness, so val accuracy is reproducible.
action_eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# ImageFolder derives class labels from the sub-folder names.
train_ds_act = datasets.ImageFolder(train_dir, transform=action_transform)
val_ds_act = datasets.ImageFolder(val_dir, transform=action_eval_transform)
train_loader_act = DataLoader(train_ds_act, batch_size=32, shuffle=True, num_workers=2)
val_loader_act = DataLoader(val_ds_act, batch_size=32, shuffle=False, num_workers=2)
# ImageNet-pretrained ResNet-18 whose head is replaced to output one logit per action class.
num_classes = len(train_ds_act.classes)
action_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
action_model.fc = nn.Linear(in_features=action_model.fc.in_features, out_features=num_classes)
action_model = action_model.to(device)
# Every parameter is optimised (nothing is frozen).
opt_act = optim.Adam(params=action_model.parameters(), lr=3e-4)
criterion_act = nn.CrossEntropyLoss()
act_epochs = 5
act_ckpt_dir = "artifacts"
# Create the folder here instead of relying on an earlier cell having made it.
os.makedirs(act_ckpt_dir, exist_ok=True)
for epoch in range(act_epochs):
    # Train for one epoch.
    action_model.train()
    train_loss = 0.0
    for imgs, labels in tqdm(train_loader_act, desc=f"action epoch {epoch + 1}"):
        imgs, labels = imgs.to(device), labels.to(device)
        logits = action_model(imgs)
        loss = criterion_act(logits, labels)
        opt_act.zero_grad()
        loss.backward()
        opt_act.step()
        train_loss += loss.item()
    # Validate: top-1 accuracy over the whole validation loader.
    action_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in val_loader_act:
            imgs, labels = imgs.to(device), labels.to(device)
            preds = action_model(imgs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    acc = correct / total if total > 0 else 0.0
    torch.save(action_model.state_dict(), os.path.join(act_ckpt_dir, f"action_epoch_{epoch + 1}.pt"))
    print(f"epoch {epoch+1} loss {train_loss / max(len(train_loader_act), 1):.4f} val_acc {acc:.3f}")

## FastAPI Inference

In [None]:
caption_ckpt = "artifacts/caption_epoch_5.pt"
action_ckpt = "artifacts/action_epoch_5.pt"
# Inference only: switch every network to eval mode.
for _net in (cnn, caption_model, action_model):
    _net.eval()

# Restore trained weights when checkpoints exist; otherwise the in-memory models are used as-is.
if os.path.exists(caption_ckpt):
    payload = torch.load(caption_ckpt, map_location=device)
    caption_model.load_state_dict(payload["model"])
    # Reuse the saved id->word table so decoding matches the trained output layer.
    vocab.itos = payload["vocab"]
    vocab.stoi = dict(zip(vocab.itos, range(len(vocab.itos))))
if os.path.exists(action_ckpt):
    action_state = torch.load(action_ckpt, map_location=device)
    action_model.load_state_dict(action_state)

app = FastAPI()
def predict_caption(img_tensor):
    """Greedy caption for one preprocessed image with a leading batch dimension of 1."""
    with torch.no_grad():
        feature = cnn(img_tensor).squeeze(0)
        token_ids = caption_model.generate(feature)
    # token_ids[0] is <start>; decode only the generated words.
    return vocab.decode(token_ids[1:])
def predict_action(img_tensor):
    """Return the class name with the highest score for a single-image batch."""
    with torch.no_grad():
        scores = action_model(img_tensor)
    best = scores.argmax(dim=1).item()
    return train_ds_act.classes[best]
def load_image(file_bytes):
    """Decode uploaded image bytes into model-ready tensors.

    Returns:
        ``(caption_tensor, action_tensor)``, each ``(1, 3, 224, 224)`` on ``device``.
        Both come from the deterministic ``caption_transform``. The action model's
        training transform applies ``RandomHorizontalFlip``, which would make
        predictions for the same upload vary between requests.

    Errors raised by PIL for undecodable bytes propagate to the caller.
    """
    with Image.open(io.BytesIO(file_bytes)) as im:
        img = im.convert('RGB')
    # caption_transform = Resize(224x224) + ToTensor + ImageNet Normalize with no
    # randomness, i.e. the action pipeline minus the flip, so it serves both models.
    t = caption_transform(img).unsqueeze(0).to(device)
    t_act = t
    return t, t_act
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    """Caption and classify an uploaded image.

    Returns ``{"caption": str, "action": str}``. Uploads that cannot be decoded
    as an image produce HTTP 400 instead of an unhandled server error.
    NOTE(review): model inference runs synchronously inside this async handler
    and blocks the event loop; consider a thread pool if concurrency matters.
    """
    from fastapi import HTTPException

    data = await file.read()
    try:
        t, t_act = load_image(data)
    except (OSError, Image.DecompressionBombError) as exc:
        # PIL reports unreadable/truncated data via OSError subclasses
        # (e.g. UnidentifiedImageError) and oversized images via DecompressionBombError.
        raise HTTPException(status_code=400, detail="Uploaded file is not a readable image") from exc
    cap = predict_caption(t)
    act = predict_action(t_act)
    return {"caption": cap, "action": act}
# if __name__ == '__main__':
#     uvicorn.run(app, host='0.0.0.0', port=8000)