In [1]:
!pip install SoccerNet torch torchvision Linformer

Collecting SoccerNet
  Downloading SoccerNet-0.1.62-py3-none-any.whl.metadata (13 kB)
Collecting Linformer
  Downloading linformer-0.2.3-py3-none-any.whl.metadata (602 bytes)
Collecting scikit-video (from SoccerNet)
  Downloading scikit_video-1.1.11-py2.py3-none-any.whl.metadata (1.1 kB)
Collecting google-measurement-protocol (from SoccerNet)
  Downloading google_measurement_protocol-1.1.0-py2.py3-none-any.whl.metadata (845 bytes)
Collecting pycocoevalcap (from SoccerNet)
  Downloading pycocoevalcap-1.2-py3-none-any.whl.metadata (3.2 kB)
Collecting boto3 (from SoccerNet)
  Downloading boto3-1.40.64-py3-none-any.whl.metadata (6.6 kB)
Collecting botocore<1.41.0,>=1.40.64 (from boto3->SoccerNet)
  Downloading botocore-1.40.64-py3-none-any.whl.metadata (5.7 kB)
Collecting jmespath<2.0.0,>=0.7.1 (from boto3->SoccerNet)
  Downloading jmespath-1.0.1-py3-none-any.whl.metadata (7.6 kB)
Collecting s3transfer<0.15.0,>=0.14.0 (from boto3->SoccerNet)
  Downloading s3transfer-0.14.0-py3-none-any.whl

In [2]:
import random
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import os, json, re
import numpy as np
from linformer import Linformer
from tqdm import tqdm
from collections import Counter
from SoccerNet.utils import getListGames


In [3]:
# Colab-only: mount Google Drive so the dataset folder under /content/drive is readable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### Dataset Organization

In [4]:
# Root folder with one sub-directory per game; each holds Labels-v2.json and the feature .npy files.
LOCAL_DIR = "/content/drive/MyDrive/SoccerNetData"

class SoccerNetCaptionDataset(Dataset):
    """Per-game dataset yielding ``(features, caption)`` pairs.

    Each item corresponds to one game of the requested split:

    * ``features`` is the float tensor stored in ``<local_dir>/<game>/<feature_type>``.
    * ``caption`` is built from ``<local_dir>/<game>/Labels-v2.json``:
      - ``mode="random"``: one randomly chosen annotation, formatted as
        ``"<label> by <team> at <gameTime>"`` (team / time omitted when empty).
      - ``mode="concat"``: the first five annotations, each formatted as
        ``"<label> (<team>) at <gameTime>"`` and joined with ``" ; "``.

    When the feature file or the label file is missing, a zero placeholder of
    shape ``(32, 2048)`` and the caption ``"No event"`` are returned. An
    unreadable or empty label file also yields ``"No event"``.

    Args:
        local_dir: root directory containing one folder per game.
        split: split name forwarded to ``getListGames``.
        feature_type: file name of the feature array inside each game folder.
        mode: ``"random"`` or ``"concat"``.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """

    _VALID_MODES = ("random", "concat")
    _FALLBACK_CAPTION = "No event"

    def __init__(self, local_dir, split="train", feature_type="1_ResNET_TF2.npy", mode="random"):
        # Fail fast: an unknown mode used to produce "No event" for every sample.
        if mode not in self._VALID_MODES:
            raise ValueError(f"mode must be one of {self._VALID_MODES}, got {mode!r}")
        self.local_dir = local_dir
        self.feature_type = feature_type
        self.games = getListGames(split=split)
        self.mode = mode  # "random" or "concat"

    def __len__(self):
        return len(self.games)

    @staticmethod
    def _read_annotations(label_path):
        """Return the list of dict annotations in ``label_path`` (empty on any read/parse failure)."""
        try:
            with open(label_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return []
        annotations = data.get("annotations") if isinstance(data, dict) else None
        if not isinstance(annotations, list):
            return []
        # Skip malformed entries instead of discarding the whole game.
        return [ann for ann in annotations if isinstance(ann, dict)]

    def _build_caption(self, annotations):
        """Format ``annotations`` according to ``self.mode``."""
        if not annotations:
            return self._FALLBACK_CAPTION

        if self.mode == "random":
            ann = random.choice(annotations)
            event_label = ann.get("label", "Unknown")
            team = ann.get("team", None)
            minute = ann.get("gameTime", None)
            caption = f"{event_label} by {team}" if team else event_label
            if minute:
                caption += f" at {minute}"
            return caption

        if self.mode == "concat":
            events = []
            for ann in annotations[:5]:
                label = ann.get("label", "Unknown")
                team = ann.get("team", "")
                minute = ann.get("gameTime", "")
                events.append(f"{label} ({team}) at {minute}")
            return " ; ".join(events)

        # Only reachable if ``self.mode`` was reassigned after construction.
        return self._FALLBACK_CAPTION

    def __getitem__(self, idx):
        game = self.games[idx]
        game_dir = os.path.join(self.local_dir, game)
        label_path = os.path.join(game_dir, "Labels-v2.json")
        feature_path = os.path.join(game_dir, self.feature_type)

        if (not os.path.exists(feature_path)) or (not os.path.exists(label_path)):
            # Placeholder sample so that an incomplete download does not abort iteration.
            dummy_feat = torch.zeros((32, 2048), dtype=torch.float32)
            return dummy_feat, self._FALLBACK_CAPTION

        features = torch.from_numpy(np.load(feature_path)).float()
        caption = self._build_caption(self._read_annotations(label_path))
        return features, caption

# Smoke test: build the training split and print the caption of one sample.
train_dataset = SoccerNetCaptionDataset(LOCAL_DIR, split="train", mode="random")
print(train_dataset[100][1])

# Same check on the validation split.
val_dataset = SoccerNetCaptionDataset(LOCAL_DIR, split="valid", mode="random")
print(val_dataset[2][1])


Shots on target by away at 2 - 36:40
Direct free-kick by home at 1 - 37:43


### Light Encoder (only there to bring the features to the right dimensions)

In [5]:
class SoccerNetEncoder(nn.Module):
    """Project frame features to ``hidden_dim``, shorten the time axis, then run a GRU.

    Args:
        input_dim: size of each incoming feature vector.
        hidden_dim: size of the projected / recurrent representation.
        pool_factor: kernel size and stride of the temporal average pooling.
        dropout: dropout probability applied at the end of the projection.
    """

    def __init__(self, input_dim=2048, hidden_dim=512, pool_factor=4, dropout=0.1):
        super().__init__()

        projection_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
        ]
        self.proj = nn.Sequential(*projection_layers)

        self.pool_factor = pool_factor

        self.temporal_gru = nn.GRU(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=False,
        )

        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, x):
        """Encode ``x`` of shape [B, T, input_dim] into [B, T', hidden_dim].

        ``T'`` is ``T // pool_factor`` when ``T > pool_factor``; otherwise the
        sequence length is left unchanged.
        """
        projected = self.proj(x)  # [B, T, H]

        # Pooling is skipped for sequences no longer than one pooling window.
        if projected.size(1) > self.pool_factor:
            channels_first = projected.permute(0, 2, 1)  # [B, H, T]
            pooled = F.avg_pool1d(
                channels_first,
                kernel_size=self.pool_factor,
                stride=self.pool_factor,
            )
            projected = pooled.permute(0, 2, 1)  # [B, T // pool_factor, H]

        recurrent_out, _last_hidden = self.temporal_gru(projected)
        return self.norm(recurrent_out)


### Decoder (GRU + Linformer)

In [6]:
class CaptionDecoder(nn.Module):
    """Linformer stack followed by a GRU and a per-step vocabulary classifier.

    Args:
        hidden_dim: feature size of the incoming sequence and of all internal layers.
        vocab_size: number of output classes per time step.
        seq_len: maximum sequence length handed to the Linformer; longer inputs are cut.
        n_heads: number of attention heads of the Linformer.
        depth: number of Linformer layers.
        linformer_k: ``k`` argument of the Linformer.
    """

    def __init__(self, hidden_dim=512, vocab_size=10000, seq_len=64, n_heads=4, depth=3, linformer_k=64):
        super().__init__()

        self.linformer = Linformer(dim=hidden_dim, seq_len=seq_len, depth=depth, heads=n_heads, k=linformer_k)
        self.gru = nn.GRU(hidden_dim, hidden_dim, batch_first=True)
        self.fc = nn.Sequential(nn.LayerNorm(hidden_dim), nn.Linear(hidden_dim, vocab_size))
        self.max_seq_len = seq_len

    def forward(self, features):
        """Map ``features`` [B, T, H] to logits [B, min(T, seq_len), vocab_size]."""
        # Only the first ``max_seq_len`` steps are kept; slicing is a no-op for shorter inputs.
        clipped = features[:, :self.max_seq_len, :]

        attended = self.linformer(clipped)  # [B, T, H]
        recurrent_out, _last_hidden = self.gru(attended)
        return self.fc(recurrent_out)  # [B, T, vocab_size]


### Model : Encoder+Decoder

In [7]:
class RealTimeVideoCaptionNet(nn.Module):
    """Full captioning network: ``SoccerNetEncoder`` feeding a ``CaptionDecoder``.

    Args:
        input_dim: size of each incoming feature vector.
        hidden_dim: shared hidden size of encoder and decoder.
        vocab_size: number of output classes per time step.
        seq_len: maximum sequence length accepted by the decoder.
    """

    def __init__(self, input_dim=2048, hidden_dim=512, vocab_size=10000, seq_len=64):
        super().__init__()
        self.encoder = SoccerNetEncoder(input_dim, hidden_dim)
        self.decoder = CaptionDecoder(hidden_dim, vocab_size, seq_len)

    def forward(self, x):
        """Return per-step vocabulary logits for ``x`` of shape [B, T, input_dim]."""
        return self.decoder(self.encoder(x))

### Captions and Vocabulary

In [8]:
def collect_captions_from_games(local_dir, games, max_files=None):
    """Build one caption string per labelled event found in the given games.

    For every game, ``<local_dir>/<game>/Labels-v2.json`` is read and each
    annotation with a non-empty ``label`` becomes ``"<label> by the <team> team"``
    (or just ``"<label>"`` when the annotation has no team).

    Args:
        local_dir: root directory containing one folder per game.
        games: iterable of game folder names, relative to ``local_dir``.
        max_files: despite its name, the maximum number of *captions* to
            return (``None`` for no limit). Values ``<= 0`` yield an empty list.

    Returns:
        List of caption strings, in game order then annotation order.
        Games whose label file is missing, unreadable or malformed are skipped;
        malformed individual annotations are skipped without dropping the
        rest of the game.
    """
    captions = []
    if max_files is not None and max_files <= 0:
        return captions

    for game in games:
        label_path = os.path.join(local_dir, game, "Labels-v2.json")

        try:
            with open(label_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Missing / unreadable file, or invalid JSON / encoding.
            continue

        annotations = data.get("annotations") if isinstance(data, dict) else None
        if not isinstance(annotations, list):
            continue

        for ann in annotations:
            if not isinstance(ann, dict):
                continue

            lbl = ann.get("label", "")
            if not isinstance(lbl, str):
                continue
            lbl = lbl.strip()
            if not lbl:
                continue

            team = ann.get("team", None)
            # gameTime / half are deliberately left out of the caption text.
            captions.append(f"{lbl} by the {team} team" if team else lbl)

            if max_files is not None and len(captions) >= max_files:
                return captions

    return captions

# List the games of the training split.
train_games = getListGames(split="train")
print(f"üìÇ Number of games (train): {len(train_games)}")

# Collect at most 8000 captions (max_files caps captions, not files) to build the vocabulary from.
captions_sample = collect_captions_from_games(LOCAL_DIR, train_games, max_files=8000)
print(f"üó£Ô∏è Extracted captions: {len(captions_sample)}")
print("üîç Example:", captions_sample[:5])


üìÇ Number of games (train): 300
üó£Ô∏è Extracted captions: 8000
üîç Example: ['Kick-off by the away team', 'Ball out of play by the not applicable team', 'Throw-in by the away team', 'Ball out of play by the not applicable team', 'Corner by the away team']


### Tokenizer

In [9]:
class SimpleTokenizer:
    """Word-level tokenizer with a frequency-ranked vocabulary.

    Words are the ``\\w+`` runs of the lower-cased text. Indices 0-3 are
    reserved for ``<PAD>``, ``<SOS>``, ``<EOS>`` and ``<UNK>``; the remaining
    entries are the most frequent words of ``captions``.

    Attributes are ``itos`` (index -> token list) and ``stoi`` (token -> index dict).
    """

    def __init__(self, captions, vocab_size=10000):
        frequencies = Counter()
        for caption in captions:
            frequencies.update(re.findall(r"\w+", caption.lower()))

        specials = ["<PAD>", "<SOS>", "<EOS>", "<UNK>"]
        ranked_words = [word for word, _count in frequencies.most_common(vocab_size - 4)]
        self.itos = specials + ranked_words
        self.stoi = dict((token, index) for index, token in enumerate(self.itos))

    def encode(self, text):
        """Return ``[<SOS>, ...word ids..., <EOS>]``; unknown words map to ``<UNK>``.

        Non-string or blank input encodes to ``[<SOS>, <EOS>]``.
        """
        sos_id = self.stoi["<SOS>"]
        eos_id = self.stoi["<EOS>"]
        if not (isinstance(text, str) and text.strip()):
            return [sos_id, eos_id]

        unk_id = self.stoi["<UNK>"]
        word_ids = [self.stoi.get(word, unk_id) for word in re.findall(r"\w+", text.lower())]
        return [sos_id] + word_ids + [eos_id]

    def decode(self, token_list):
        """Join the tokens of ``token_list`` with spaces.

        Ids 0, 1 and 2 (``<PAD>``, ``<SOS>``, ``<EOS>``) are dropped; ids beyond
        the vocabulary are rendered as ``<UNK>``.
        """
        vocab_len = len(self.itos)
        kept_ids = (t for t in token_list if t not in (0, 1, 2))
        return " ".join(self.itos[t] if t < vocab_len else "<UNK>" for t in kept_ids)

# Build the tokenizer from the sampled captions; fall back to a one-caption vocabulary when none were extracted.
if len(captions_sample) == 0:
    print("‚ö†Ô∏è No caption found")
    tokenizer = SimpleTokenizer(["no event"], vocab_size=1000)
else:
    tokenizer = SimpleTokenizer(captions_sample, vocab_size=10000)

# Actual vocabulary size (special tokens included); later used to size the model's output layer.
vocab_size = len(tokenizer.itos)
print(f"üßæ Total Vocabulary: {vocab_size} words")
print("üß© Examples :", tokenizer.itos[:34])

üßæ Total Vocabulary: 35 words
üß© Examples : ['<PAD>', '<SOS>', '<EOS>', '<UNK>', 'by', 'the', 'team', 'away', 'home', 'ball', 'out', 'of', 'play', 'not', 'applicable', 'throw', 'in', 'kick', 'shots', 'target', 'foul', 'free', 'clearance', 'indirect', 'off', 'on', 'corner', 'substitution', 'direct', 'card', 'yellow', 'offside', 'goal', 'penalty']


### Data Loaders

In [10]:
def collate_fn(batch):
    """Pad a list of ``(features, caption)`` samples into batch tensors.

    Args:
        batch: list of ``(features [T_i, D], caption_str)``; features may be
            a torch tensor or a NumPy array.

    Returns:
        features_padded: [B, max_T, D], zero-padded along time.
        targets_padded:  [B, max_L] token ids from the module-level
            ``tokenizer``, padded with the ``<PAD>`` id.
    """
    pad_id = tokenizer.stoi["<PAD>"]

    features_list = [
        torch.from_numpy(feat).float() if isinstance(feat, np.ndarray) else feat
        for feat, _caption in batch
    ]
    token_list = [
        torch.tensor(tokenizer.encode(caption), dtype=torch.long)
        for _feat, caption in batch
    ]

    features_padded = pad_sequence(features_list, batch_first=True)
    targets_padded = pad_sequence(token_list, batch_first=True, padding_value=pad_id)
    return features_padded, targets_padded

# Report how many games each split contains.
print("üì¶ Size of train_dataset:", len(train_dataset))
print("üì¶ Size of val_dataset:", len(val_dataset))

# Batches of 4 games; collate_fn pads features and tokenized captions to a common length.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader   = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

# Pull one batch to check the padded shapes and decode its first caption back to text.
batch_features, batch_targets = next(iter(train_loader))
print("‚úÖ Batch features shape:", batch_features.shape)
print("‚úÖ Batch targets shape:", batch_targets.shape)
print("üí¨ Decoded example :", tokenizer.decode(batch_targets[0].tolist()))

üì¶ Size of train_dataset: 300
üì¶ Size of val_dataset: 100
‚úÖ Batch features shape: torch.Size([4, 5400, 2048])
‚úÖ Batch targets shape: torch.Size([4, 11])
üí¨ Decoded example : <UNK> <UNK>


### Training

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("üöÄ Device:", device)

# Persist the vocabulary BEFORE training: checkpoints are written during the loop,
# so an interrupted run must still leave a vocabulary file that matches them.
with open("tokenizer_vocab.json", "w", encoding="utf-8") as f:
    json.dump(tokenizer.itos, f, ensure_ascii=False, indent=2)
print("üíæ Saved vocabulary : tokenizer_vocab.json")

model = RealTimeVideoCaptionNet(
    input_dim=2048,
    hidden_dim=512,
    vocab_size=vocab_size,
).to(device)

# Padded target positions must not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.stoi["<PAD>"])
# Only the decoder's parameters are optimized; the encoder keeps its initial weights.
# NOTE(review): confirm that leaving the encoder untrained is intended.
optimizer = optim.AdamW(model.decoder.parameters(), lr=1e-4)

EPOCHS = 5
best_val_loss = float("inf")


def _batch_loss(features, targets):
    """Cross-entropy between the model logits and ``targets`` over their common length.

    Logits and targets are both cut to the shorter of the two sequence lengths
    before being flattened for the criterion.
    """
    features, targets = features.to(device), targets.to(device)
    outputs = model(features)
    min_len = min(outputs.size(1), targets.size(1))
    outputs = outputs[:, :min_len, :].reshape(-1, vocab_size)
    targets = targets[:, :min_len].reshape(-1)
    return criterion(outputs, targets)


for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    train_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")

    for features, targets in pbar:
        loss = _batch_loss(features, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    avg_train_loss = train_loss / len(train_loader)

    # ---- validation pass ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for features, targets in val_loader:
            val_loss += _batch_loss(features, targets).item()
    avg_val_loss = val_loss / len(val_loader)

    print(f"‚úÖ Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    # Keep the weights of the epoch with the lowest validation loss.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), "best_decoder_linformer_gru_soccernet.pth")
        print("üíæ New best model saved !")

torch.save(model.state_dict(), "decoder_linformer_gru_soccernet_final.pth")
print("üíæ Final model : decoder_linformer_gru_soccernet_final.pth")


üöÄ Device: cuda


Epoch 1/5 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 75/75 [13:45<00:00, 11.00s/it, loss=1.6911]


‚úÖ Epoch 1/5 | Train Loss: 1.7508 | Val Loss: 1.0222
üíæ New best model saved !


Epoch 2/5 [Train]:  27%|‚ñà‚ñà‚ñã       | 20/75 [00:15<00:40,  1.36it/s, loss=1.1235]

### Model and Vocabulary Selection

In [14]:
VOCAB_PATH = "/content/tokenizer_vocab.json"
MODEL_PATH = "/content/best_decoder_linformer_gru_soccernet.pth"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("üöÄ Device:", device)

if not os.path.exists(VOCAB_PATH):
    raise FileNotFoundError(f"‚ùå Vocabulary not found : {VOCAB_PATH}")

# The file is written with UTF-8, so read it back the same way.
with open(VOCAB_PATH, "r", encoding="utf-8") as f:
    vocab_data = json.load(f)

# Accept the plain list written by the training cell as well as a few dict layouts.
if isinstance(vocab_data, list):
    idx2word = {i: w for i, w in enumerate(vocab_data)}
elif isinstance(vocab_data, dict):
    if "idx2word" in vocab_data:
        idx2word = {int(k): v for k, v in vocab_data["idx2word"].items()}
    elif all(k.isdigit() for k in vocab_data.keys()):
        idx2word = {int(k): v for k, v in vocab_data.items()}
    elif "itos" in vocab_data:
        idx2word = {i: w for i, w in enumerate(vocab_data["itos"])}
    elif "vocab" in vocab_data:
        idx2word = {i: w for i, w in enumerate(vocab_data["vocab"])}
    else:
        raise ValueError("‚ùå Non Recognized Vocabulary Format")
else:
    raise ValueError("‚ùå Unvalid Vocabulary Format")

vocab_size = len(idx2word)
print(f"üìö Vocabulary is downloaded : {vocab_size} words")
print("üß© Examples :", list(idx2word.values())[:10])

if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"‚ùå Model not found : {MODEL_PATH}")

# Build the network only now, so that its output layer is sized from the
# vocabulary that was just loaded rather than from a value left over in the kernel.
model = RealTimeVideoCaptionNet(
    input_dim=2048,
    hidden_dim=512,
    vocab_size=vocab_size,
).to(device)

# load_state_dict is strict: the checkpoint must come from this exact architecture.
state_dict = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(state_dict)
model.eval()

print("‚úÖ Model and Vocabulary are successfully selected")


üöÄ Device: cuda
üìö Vocabulary is downloaded : 98 words
üß© Examples : ['<PAD>', '<SOS>', '<EOS>', '<UNK>', 'by', 'the', 'team', 'at', '2', '1']


RuntimeError: Error(s) in loading state_dict for RealTimeVideoCaptionNet:
	Missing key(s) in state_dict: "encoder.proj.0.weight", "encoder.proj.0.bias", "encoder.proj.1.weight", "encoder.proj.1.bias", "encoder.temporal_gru.weight_ih_l0", "encoder.temporal_gru.weight_hh_l0", "encoder.temporal_gru.bias_ih_l0", "encoder.temporal_gru.bias_hh_l0", "encoder.norm.weight", "encoder.norm.bias", "decoder.fc.0.weight", "decoder.fc.0.bias", "decoder.fc.1.weight", "decoder.fc.1.bias". 
	Unexpected key(s) in state_dict: "encoder.proj.weight", "encoder.proj.bias", "decoder.linformer.net.layers.3.0.fn.proj_k", "decoder.linformer.net.layers.3.0.fn.proj_v", "decoder.linformer.net.layers.3.0.fn.to_q.weight", "decoder.linformer.net.layers.3.0.fn.to_k.weight", "decoder.linformer.net.layers.3.0.fn.to_v.weight", "decoder.linformer.net.layers.3.0.fn.to_out.weight", "decoder.linformer.net.layers.3.0.fn.to_out.bias", "decoder.linformer.net.layers.3.0.norm.weight", "decoder.linformer.net.layers.3.0.norm.bias", "decoder.linformer.net.layers.3.1.fn.w1.weight", "decoder.linformer.net.layers.3.1.fn.w1.bias", "decoder.linformer.net.layers.3.1.fn.w2.weight", "decoder.linformer.net.layers.3.1.fn.w2.bias", "decoder.linformer.net.layers.3.1.norm.weight", "decoder.linformer.net.layers.3.1.norm.bias", "decoder.fc.weight", "decoder.fc.bias". 

In [None]:

# ==========================================================
# Load the features of one game
# ==========================================================
game_path = "/content/drive/MyDrive/SoccerNetData/england_epl/2014-2015/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley"
feature_path = os.path.join(game_path, "1_ResNET_TF2.npy")

if not os.path.exists(feature_path):
    raise FileNotFoundError(f"‚ùå Fichier introuvable : {feature_path}")

print(f"üìÇ Chargement des features depuis : {feature_path}")

features = np.load(feature_path)
features = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)  # [1, T, 2048]

print(f"‚úÖ Features charg√©es : {features.shape}")

# ==========================================================
# Split into fixed-length sub-sequences
# ==========================================================
max_len = 64
segments = []

for start in range(0, features.shape[1], max_len):
    seg = features[:, start:start + max_len, :]
    missing = max_len - seg.shape[1]
    if missing > 0:
        # Zero-pad the end of the last, shorter segment along the time axis.
        seg = F.pad(seg, (0, 0, 0, missing))
    segments.append(seg)

print(f"üìè Nombre de segments : {len(segments)}")

# ==========================================================
# Inference: greedy (argmax) token per output step
# ==========================================================
predicted_tokens = []

model.eval()
with torch.no_grad():
    for seg in segments:
        step_logits = model(seg)  # [1, T, vocab_size]
        predicted_tokens += step_logits.argmax(dim=-1)[0].tolist()

# ==========================================================
# Rebuild the caption text
# ==========================================================
special_ids = {0, 1, 2, 3}  # <PAD>, <SOS>, <EOS>, <UNK>
caption_words = [idx2word.get(t, "<UNK>") for t in predicted_tokens if t not in special_ids]

# Simple clean-up: collapse runs of the same word into a single occurrence.
clean_caption = [
    word
    for position, word in enumerate(caption_words)
    if position == 0 or word != caption_words[position - 1]
]

caption_text = " ".join(clean_caption).replace("<UNK>", "").strip()

print("\nüèüÔ∏è Caption g√©n√©r√©e :")
print("üó®Ô∏è", caption_text if caption_text else "(vide ou non d√©codable)")


üìÇ Chargement des features depuis : /content/drive/MyDrive/SoccerNetData/england_epl/2014-2015/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley/1_ResNET_TF2.npy
‚úÖ Features charg√©es : torch.Size([1, 5400, 2048])
üìè Nombre de segments : 85

üèüÔ∏è Caption g√©n√©r√©e :
üó®Ô∏è foul by at 1 throw in by at 1 foul in by at 1 throw by at 1 throw in by at 1 throw in by at 1 throw in by at 1 throw in by at 1 ball by at 1 throw by at 1 throw by at 1 throw in by at 1 applicable throw in by at 1 throw by at 1 foul in by at 1 throw in by at 1 foul in by at 1 throw in by at 1 ball by at 1 foul in by at 1 throw in by at 1 ball by at 1 ball by at 1 throw in by at 1 foul in by at 1 ball by at 1 throw in by at 1 throw in by at 1 throw in by at 1 throw in by at 1 throw in by at 1 foul by at 1 throw in by at 1 throw in by at 1 throw by at 1 ball by at 1 throw in by at 1 throw in by at 1 throw in by at 1 throw in by at 1 throw in by at 1 throw in by at 1 throw in by at 1 throw in by at 1 throw in by at 1 f

In [None]:
# Inspect the in-memory vocabulary (list position == token id).
print(tokenizer.itos)

['<PAD>', '<SOS>', '<EOS>', '<UNK>', 'by', 'the', 'team', 'at', '2', '1', 'away', 'home', 'ball', 'out', 'of', 'play', 'not', 'applicable', 'throw', 'in', 'kick', 'shots', 'target', 'foul', 'free', 'clearance', 'indirect', 'off', '00', 'on', 'corner', '24', '05', '04', '29', '25', '39', '03', '41', '35', '10', '43', '16', '34', '42', '09', '37', '27', '13', '14', '40', '11', '20', '02', '30', '22', '28', '21', '19', '01', '31', '08', '15', '17', '23', '12', '06', '33', '36', '18', '26', '44', '38', '32', '07', '45', 'substitution', '47', '46', 'direct', '49', '48', 'card', 'yellow', '51', '55', '53', '57', '59', '52', '58', '56', '50', 'offside', '54', 'goal', 'penalty', 'red']


In [None]:
import torch
import numpy as np
import os
import re
import json
import torch.nn.functional as F

# ==========================================================
# Load the model and the vocabulary
# ==========================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load vocabulary (written as UTF-8 by the training cell) ---
vocab_path = "/content/tokenizer_vocab.json"
with open(vocab_path, "r", encoding="utf-8") as f:
    vocab_data = json.load(f)

# Compatibility with several layouts. The training cell dumps ``tokenizer.itos``,
# i.e. a plain JSON list, so that format must be accepted as well.
if isinstance(vocab_data, list):
    idx2word = {i: w for i, w in enumerate(vocab_data)}
elif isinstance(vocab_data, dict):
    if "idx2word" in vocab_data:
        idx2word = {int(k): v for k, v in vocab_data["idx2word"].items()}
    elif all(isinstance(k, str) and k.isdigit() for k in vocab_data.keys()):
        # Only treat the dict as {index: word} when every key really is an integer.
        idx2word = {int(k): v for k, v in vocab_data.items()}
    else:
        raise ValueError("Format de vocab invalide.")
else:
    raise ValueError("Fichier vocabulaire inattendu.")

vocab_size = len(idx2word)
print(f"üìñ Vocabulaire charg√© ({vocab_size} tokens)")

# --- Load the model ---
# RealTimeVideoCaptionNet has no ``reduce_factor`` parameter; temporal reduction
# is done separately by reduce_features() further down in this cell.
model = RealTimeVideoCaptionNet(
    input_dim=2048,
    hidden_dim=512,
    vocab_size=vocab_size,
).to(device)

model.load_state_dict(torch.load("/content/best_decoder_linformer_gru_soccernet.pth", map_location=device))
model.eval()

print("‚úÖ Mod√®le charg√© et pr√™t pour l‚Äôinf√©rence")

# ==========================================================
# Load and reduce the features
# ==========================================================
game_path = "/content/drive/MyDrive/SoccerNetData/england_epl/2014-2015/2015-02-21 - 18-00 Chelsea 1 - 1 Burnley"
feature_path = os.path.join(game_path, "1_ResNET_TF2.npy")

print(f"üìÇ Chargement des features depuis : {feature_path}")
features = np.load(feature_path)
# Add a leading batch dimension and move the array to the selected device.
features = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)  # [1, T, 2048]
print("‚úÖ Features shape :", features.shape)

# --- R√©duction temporelle dynamique (au lieu de 85 segments s√©par√©s) ---
def reduce_features(x, factor=100):
    """Shorten ``x`` along time by averaging non-overlapping blocks of ``factor`` frames.

    Args:
        x: tensor of shape [B, T, D].
        factor: number of consecutive frames averaged into one step (>= 1).

    Returns:
        Tensor of shape [B, T // factor, D]; the trailing ``T % factor`` frames
        are dropped. When ``0 < T < factor`` all frames are averaged into a
        single step ([B, 1, D]) instead of returning an empty sequence. An
        input with ``T == 0`` is returned unchanged.

    Raises:
        ValueError: if ``factor`` is smaller than 1.
    """
    if factor < 1:
        raise ValueError(f"factor must be >= 1, got {factor}")

    B, T, D = x.shape
    if T == 0:
        return x

    T_new = T // factor
    if T_new == 0:
        # Fewer frames than one block: keep one averaged step rather than nothing.
        return x.mean(dim=1, keepdim=True)

    x = x[:, :T_new * factor, :]  # drop the incomplete trailing block
    # reshape (not view) so that non-contiguous inputs are handled too.
    return x.reshape(B, T_new, factor, D).mean(dim=2)

reduced_feats = reduce_features(features, factor=64)  # ~64x shorter along the time axis
print("üìâ Features r√©duites :", reduced_feats.shape)

# ==========================================================
# üß† Pr√©diction avec Top-k Sampling
# ==========================================================
def top_k_sampling(logits, k=5, temperature=1.0):
    """Sample one token id among the ``k`` most probable entries of ``logits``.

    Args:
        logits: 1-D tensor of unnormalized scores, one per vocabulary entry.
        k: number of candidates to sample from; clamped to ``[1, len(logits)]``
            so a small vocabulary does not make ``torch.topk`` fail.
        temperature: positive softmax temperature (1.0 leaves the logits unchanged).

    Returns:
        The sampled token id as a Python ``int``.

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    k = max(1, min(int(k), logits.size(-1)))
    probs = F.softmax(logits / temperature, dim=-1)
    topk_probs, topk_indices = torch.topk(probs, k)
    # multinomial accepts unnormalized weights, so the top-k slice needs no renormalization.
    sampled = torch.multinomial(topk_probs, num_samples=1)
    next_token = topk_indices.gather(-1, sampled)
    return next_token.item()

predicted_tokens = []

# One forward pass over the reduced sequence, then one sampled token per output step.
with torch.no_grad():
    logits = model(reduced_feats)  # [1, seq_len, vocab]
    predicted_tokens.extend(top_k_sampling(step_logits, k=5) for step_logits in logits[0])

# ==========================================================
# Output clean-up
# ==========================================================
special_ids = {0, 1, 2, 3}  # <PAD>, <SOS>, <EOS>, <UNK>
caption_words = [idx2word.get(t, "<UNK>") for t in predicted_tokens if t not in special_ids]

# Collapse runs of the same word into a single occurrence.
clean_caption = [
    word
    for position, word in enumerate(caption_words)
    if position == 0 or word != caption_words[position - 1]
]

caption_text = re.sub(r"\s+", " ", " ".join(clean_caption)).strip()

print("\nüèüÔ∏è Caption g√©n√©r√©e :")
print("üó®Ô∏è", caption_text)
