In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.models as models
import numpy as np
import random
from itertools import chain, combinations
from sklearn.metrics import average_precision_score
from sklearn.model_selection import train_test_split
from pathlib import Path
import threading
from multiprocessing import Manager

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

cuda:0


Handling Data

In [2]:
class ShardedEpisodicDataset():
  def __init__(self, data_paths, meta_path, min_n_way, try_k_shot, try_n_query, split="train", test_split=0.1, val_split=0.1, split_seed=33):
    self.data_paths = data_paths
    self.min_n_way = min_n_way
    self.try_k_shot = try_k_shot
    self.try_n_query = try_n_query

    # load metadata
    meta = torch.load(meta_path)
    self.rel_idxs = meta['rel_idxs']
    self.n_examples = self.rel_idxs[-1] + 1
    meta_class_idxs = meta['class_idxs']
    self.n_classes = len(meta_class_idxs)

    # filter classes based on split type
    all_idxs = list(range(self.n_examples))
    train_idxs, test_idxs = train_test_split(all_idxs, test_size=test_split + val_split, random_state=split_seed)
    test_idxs, val_idxs = train_test_split(test_idxs, test_size=val_split / (test_split + val_split), random_state=split_seed)
    print(f"train_size : {len(train_idxs)}, test_size : {len(test_idxs)}, val_size : {len(val_idxs)}")
    if split == "train":
      valid_idxs = set(train_idxs)
    elif split == "test":
      valid_idxs = set(test_idxs)
    elif split == 'val':
      valid_idxs = set(val_idxs)
    else:
      raise ValueError("Critical: split is not train, test or val")

    self.class_idxs = [[i.item() for i in meta_class_idxs[c] if i.item() in valid_idxs] for c in range(self.n_classes)]

    # filter valid classes
    self.valid_classes = [
      c for c, indices in enumerate(self.class_idxs)
      if len(indices) >= self.try_k_shot + self.try_n_query
    ]

    print(f"{len(self.valid_classes)} of {len(self.class_idxs)} classes have atleast try_k_shot + try_n_query ({self.try_k_shot + self.try_n_query}) examples")

    if len(self.valid_classes) < self.min_n_way:
      raise ValueError(f"critical: try_k_shot + try_n_query is too big or min_n_way is too big, there are {len(self.valid_classes)} valid classes")

    # cache management
    self.cache_size = 1
    self.manager = Manager()
    self.cache = self.manager.dict()
    self.lock = threading.Lock()

  def __len__(self):
    return 1 # cuz this number sounds good

  def _get_idxs(self, idx):
    # binary search on shard to find shard
    low, high = 0, len(self.rel_idxs) - 1
    while (low < high):
      mid = (low + high) // 2
      if self.rel_idxs[mid] < idx:
        low = mid + 1
      else:
        high = mid
    shard_idx = low
    
    if (shard_idx == 0):
      abs_idx = idx
    else:
      abs_idx = idx - self.rel_idxs[shard_idx - 1] - 1
    return shard_idx, abs_idx

  def _load_shard(self, shard_idx):
    shard_path = self.data_paths[shard_idx]
    
    with self.lock:
      # Check if shard is already in cache
      if shard_path in self.cache:
        return self.cache[shard_path]
            
      # If cache is full, remove one item
      if len(self.cache) >= self.cache_size:
        # More predictable eviction strategy
        oldest_key = next(iter(self.cache))
        self.cache.pop(oldest_key)
            
      # Load the shard and store in cache
      shard = torch.load(shard_path)
      self.cache[shard_path] = shard
        
    return shard

  def __getitem__(self, idx):
    # Choose random subset of n_way classes
    selected_classes = random.sample(self.valid_classes, self.min_n_way)

    support_indices, query_indices = [], []

    for c in selected_classes:
      indices = self.class_idxs[c]
      if not indices:
        continue
      selected_indices = random.sample(indices, self.try_k_shot + self.try_n_query)
      support_indices.extend(selected_indices[:self.try_k_shot])
      query_indices.extend(selected_indices[self.try_k_shot:])

    support_indices = torch.tensor(support_indices, dtype=torch.long)
    query_indices = torch.tensor(query_indices, dtype=torch.long)

    unique_shards = set(self._get_idxs(i)[0] for i in support_indices.tolist() + query_indices.tolist())
    shard_data = {shard_idx: self._load_shard(shard_idx) for shard_idx in unique_shards}

    # Fetch data/labels from those shards
    support_data = torch.cat([shard_data[self._get_idxs(i)[0]]['imgs'][self._get_idxs(i)[1]].unsqueeze(0) for i in support_indices])
    support_label = torch.cat([shard_data[self._get_idxs(i)[0]]['lbls'][self._get_idxs(i)[1]].unsqueeze(0) for i in support_indices])
    query_data = torch.cat([shard_data[self._get_idxs(i)[0]]['imgs'][self._get_idxs(i)[1]].unsqueeze(0) for i in query_indices])
    query_label = torch.cat([shard_data[self._get_idxs(i)[0]]['lbls'][self._get_idxs(i)[1]].unsqueeze(0) for i in query_indices])

    return {
      'support_data': support_data, # (~n_way * ~k_shot, C, H, W)
      'support_label': support_label, # (~n_way * ~k_shot, num_classes)
      'query_data': query_data, # (~n_way * ~n_query, C, H, W)
      'query_label': query_label # (~n_way * ~n_query, num_classes)
    }

In [None]:
data_paths = list(Path('./processed').rglob('*.pth'))
data_paths.sort()
meta_path = data_paths.pop()
meta = torch.load(meta_path)

  meta = torch.load(meta_path)


In [4]:
dataset = ShardedEpisodicDataset(data_paths, meta_path, 2, 2, 2, split="train")
val_dataset = ShardedEpisodicDataset(data_paths, meta_path, 2, 2, 2, split="val")
test_dataset = ShardedEpisodicDataset(data_paths, meta_path, 2, 2, 2, split="test")

train_size : 800, test_size : 100, val_size : 100
88 of 200 classes have atleast try_k_shot + try_n_query (4) examples


  meta = torch.load(meta_path)


train_size : 800, test_size : 100, val_size : 100
16 of 200 classes have atleast try_k_shot + try_n_query (4) examples
train_size : 800, test_size : 100, val_size : 100
12 of 200 classes have atleast try_k_shot + try_n_query (4) examples


CNN stuff

In [5]:
def l2_regularization(model, lambda_l2=1e-4):
    l2_norm = 0
    for param in model.parameters():
        l2_norm += param.pow(2).sum()
    return lambda_l2 * l2_norm

In [6]:
class ConvEncoder(nn.Module):
  def __init__(self):
    super().__init__()
    resnet = models.resnet18(weights=True)
    resnet = nn.Sequential(*list(resnet.children())[:-2])
    self.encoder = resnet

  def forward(self, x):
    x = self.encoder(x) # (batch_size, 512, 7, 7) for input of (batch_size, 3, 224, 224)
    return x

ProtoNet stuff

In [7]:
def euclidean_dist(x, y):
  n_class, n_sample, n_feature = y.shape[0], x.shape[0], x.shape[1]

  x = x.unsqueeze(1).expand(-1, n_class, -1)
  y = y.unsqueeze(0).expand(n_sample, -1, -1)

  return torch.pow(x - y, 2).sum(2)

In [8]:
class FocalLoss(nn.Module):
  def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
    super().__init__()
    self.gamma = gamma
    self.alpha = alpha
    self.reduction = reduction

  def forward(self, logits, targets):
    bce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    pt = torch.exp(-bce_loss)
    focal_loss = ((1 - pt) ** self.gamma) * bce_loss
    if self.alpha is not None:
      focal_loss = self.alpha * focal_loss
    if self.reduction == 'mean':
      return focal_loss.mean()
    elif self.reduction == 'sum':
      return focal_loss.sum()
    else:
      return focal_loss

In [9]:
def compute_mAP(y_true, y_scores):
  num_classes = y_true.shape[1]
  APs = []
  for i in range(num_classes):
    if np.sum(y_true[:, i]) >0:
      AP = average_precision_score(y_true[:, i], y_scores[:, i])
      APs.append(AP)
  return np.mean(APs) if len(APs) > 0 else 0.0

In [None]:
def prototypical_loss(model, episode, device, use_focal=True, gamma=2.0, alpha=1.0):
  model.train()
  model.to(device)
  support_label = episode['support_label'].to(device)
  query_label = episode['query_label'].to(device)
  support_data = episode['support_data'].to(device)
  query_data = episode['query_data'].to(device)

  # create extended vector of labels for each support and query
  num_classes = episode['support_label'].shape[1]

  unique_combinations = set()
  for label_set in support_label:
      active = tuple(i for i, val in enumerate(label_set) if val == 1)
      if active:
          unique_combinations.add(active)
  for label_set in query_label:
      active = tuple(i for i, val in enumerate(label_set) if val == 1)
      if active:
          unique_combinations.add(active)

  def get_power_set(indices):
    s = list(indices)
    return list(chain.from_iterable(combinations(s, r) for r in range(1, len(s) + 1)))

  all_subsets = set()
  for idx_comb in unique_combinations:
    for subsets in get_power_set(idx_comb):
      all_subsets.add(tuple(sorted(subsets)))

  all_subsets = sorted(all_subsets, key=lambda s: (len(s), s))
  subset_map = {subset: i for i, subset in enumerate(all_subsets)} # subset tuple to idx in extended multi hot encoded

  def extend_labels(labels):
    ext = torch.zeros((len(labels), len(all_subsets)), dtype=torch.float, device=device)
    for i, row in enumerate(labels):
      active_indices = tuple(sorted(j for j, val in enumerate(row) if val == 1))
      if active_indices:
        for subset in get_power_set(active_indices):
          subset = tuple(sorted(subset))
          ext[i, subset_map[subset]] = 1
    return ext

  support_ext_lbls = extend_labels(support_label)
  query_ext_lbls = extend_labels(query_label)

  # calculate prototype for each label combination
  """
  support_embeddings = []
  for x in support_data:
    x = x.unsqueeze(0)
    embed = model(x)
    support_embeddings.append(embed.cpu())
    del x, embed  # Manually delete tensors
    torch.cuda.empty_cache()  # Clear CUDA cache
  support_embeddings = torch.cat(support_embeddings, dim=0)  # Stack results

  query_embeddings = []
  for x in query_data:
    x = x.unsqueeze(0)
    embed = model(x)
    query_embeddings.append(embed.cpu())
    del x, embed  # Manually delete tensors
    torch.cuda.empty_cache()  # Clear CUDA cache
  query_embeddings = torch.cat(query_embeddings, dim=0)  # Stack results
  """
  support_embeddings = model(support_data)
  query_embeddings = model(query_data)

  d = support_embeddings.shape[1]
  num_proto = len(all_subsets)
  prototypes = torch.zeros((num_proto, d), device=device)
  counts = torch.zeros(num_proto, device=device)

  for i in range(support_embeddings.shape[0]):
      active_prototypes = support_ext_lbls[i].nonzero(as_tuple=True)[0]
      for p in active_prototypes:
          prototypes[p] += support_embeddings[i]
          counts[p] += 1
  counts = counts.clamp(min=1)
  prototypes /= counts.unsqueeze(1)  # Compute mean embeddings

  # Compute distances between query embeddings and prototypes
  distances = euclidean_dist(query_embeddings, prototypes)  # Shape: (num_query, num_subsets)
  logits = -distances

  if use_focal:
    pos_freq = query_ext_lbls.float().mean(dim=0)
    epsilon = 1e-6
    raw_pos_weight = 1.0 / (pos_freq + epsilon)
    alpha = torch.clamp(raw_pos_weight, max=10.0)
    loss_fn = FocalLoss(gamma=gamma, alpha=torch.tensor(alpha, device=device) if isinstance(alpha, (list, np.ndarray)) else alpha)
  else:
    loss_fn = nn.BCEWithLogitsLoss()
  loss = loss_fn(logits, query_ext_lbls)

  with torch.no_grad():
    probs = torch.sigmoid(logits)
    y_true = query_ext_lbls.cpu().numpy()
    y_scores = probs.cpu().numpy()
    mAP = compute_mAP(y_true, y_scores)

  pred = (probs > 0.5).float()
  true_bool = query_ext_lbls.bool()
  tp = (pred * true_bool.float()).sum().float()
  fp = (pred * (~true_bool).float()).sum().float()
  fn = ((~pred.bool()) * true_bool.float()).sum().float()
  precision = tp / (tp + fp + 1e-6)
  recall = tp / (tp + fn + 1e-6)
  f1_score = 2 * (precision * recall) / (precision + recall + 1e-6)

  return loss, f1_score, mAP

  # # Apply sigmoid and compute loss using BCE
  # #print(distances.shape)
  # #print(query_ext_lbls.float().shape)
  # pos_freq = query_ext_lbls.float().mean(dim=0)
  # epsilon = 1e-6
  # raw_pos_weight = 1.0 / (pos_freq + epsilon)
  # max_weight = 10.0
  # pos_weight  = torch.clamp(raw_pos_weight, max = max_weight)
  # loss_fn = nn.BCEWithLogitsLoss(pos_weight = pos_weight)
  # loss = loss_fn(distances, query_ext_lbls.float())

  # # compress the extended prediction back to normal
  # pred = torch.zeros(episode['query_data'].shape[0], num_classes, dtype=torch.bool, device=device)
  # for i in range(episode['query_data'].shape[0]):
  #   min_dist = torch.min(distances[i])  # 1. Get smallest distance value for this query
  #   candidates = (distances[i] == min_dist).nonzero(as_tuple=True)[0]  # 2. Find ALL prototypes with this distance
  #   best_idx = max(candidates, key=lambda idx: len(all_subsets[idx]))  # 3. Select largest subset
  #   pred[i, all_subsets[best_idx]] = True

  # true_bool = episode['query_label'].to(device).bool()
  # tp = (pred & true_bool).sum().float()
  # fp = (pred & ~true_bool).sum().float()
  # fn = (~pred & true_bool).sum().float()

  # precision = tp / (tp + fp + 1e-6)  # Adding small value to avoid division by zero
  # recall = tp / (tp + fn + 1e-6)
  # f1_score = 2 * (precision * recall) / (precision + recall + 1e-6)

  # return loss, f1_score

In [11]:
def predict(model, episode, device):
    model.eval()
    model.to(device)

    num_classes = episode['support_label'].shape[1]
    unique_combinations = set()
    for label_set in episode['support_label']:
        active = tuple(i for i, val in enumerate(label_set) if val == 1)
        if active:
            unique_combinations.add(active)
    for label_set in episode['query_label']:
        active = tuple(i for i, val in enumerate(label_set) if val == 1)
        if active:
            unique_combinations.add(active)

    def get_power_set(indices):
        return list(chain.from_iterable(combinations(indices, r) for r in range(1, len(indices) + 1)))

    all_subsets = set()
    for idx_comb in unique_combinations:
        for subset in get_power_set(idx_comb):
            all_subsets.add(tuple(sorted(subset)))
    all_subsets = sorted(all_subsets, key=lambda s: (len(s), s))
    subset_map = {subset: i for i, subset in enumerate(all_subsets)}

    def extend_labels(labels):
        ext = torch.zeros((len(labels), len(all_subsets)), dtype=torch.float, device=device)
        for i, row in enumerate(labels):
            active_indices = tuple(sorted(j for j, val in enumerate(row) if val == 1))
            if active_indices:
                for subset in get_power_set(active_indices):
                    subset = tuple(sorted(subset))
                    ext[i, subset_map[subset]] = 1.0
        return ext

    support_ext_lbls = extend_labels(episode['support_label'])
    query_ext_lbls = extend_labels(episode['query_label'])

    with torch.no_grad():
        support_embeddings = model(episode['support_data'].to(device))
        query_embeddings = model(episode['query_data'].to(device))
    d = support_embeddings.shape[1]
    num_proto = len(all_subsets)
    prototypes = torch.zeros((num_proto, d), device=device)
    counts = torch.zeros(num_proto, device=device)
    for i in range(support_embeddings.shape[0]):
        active_prototypes = support_ext_lbls[i].nonzero(as_tuple=True)[0]
        for p in active_prototypes:
            prototypes[p] += support_embeddings[i]
            counts[p] += 1
    counts = counts.clamp(min=1)
    prototypes /= counts.unsqueeze(1)

    distances = euclidean_dist(query_embeddings, prototypes)
    logits = -distances  # Convert distances to logits
    probs = torch.sigmoid(logits)
    pred = (probs > 0.5).int()

    return pred, probs

GAT stuff

In [12]:
class GATLayer(nn.Module):
    def __init__(self, c_in, c_out, num_heads=1, concat_heads=True, alpha=0.3, top_k=5):
        super().__init__()
        self.num_heads = num_heads
        self.concat_heads = concat_heads
        self.top_k = top_k

        if self.concat_heads:
            assert c_out % num_heads == 0, "c_out must be divisible by num_heads"
            c_out = c_out // num_heads

        self.projection = nn.Linear(c_in, c_out * num_heads)
        self.a = nn.Parameter(torch.Tensor(num_heads, 2 * c_out))
        self.leakyrelu = nn.LeakyReLU(alpha)

        nn.init.xavier_uniform_(self.projection.weight, gain=1.414)
        nn.init.xavier_uniform_(self.a, gain=1.414)

    def compute_adj_matrix(self, node_feats, drop_prob=0.2):
        # Normalize features along the last dimension (using lowercase 'normalize')
        norm_feats = F.normalize(node_feats, p=2, dim=-1)
        # Here, node_feats has shape (batch_size, num_nodes, feature_dim).
        # We need to compute pairwise similarity per batch.
        # Use torch.bmm to perform batched matrix multiplication.
        sim = torch.bmm(norm_feats, norm_feats.transpose(1, 2))  # (batch, num_nodes, num_nodes)
        # Create a binary adjacency matrix based on a threshold (e.g., 0.5)
        if self.top_k is not None:
          B, N, _ = sim.shape
          topk_values, top_k_indices = torch.topk(sim, k=self.top_k, dim=-1)
          mask = torch.zeros_like(sim)
          mask.scatter_(2, top_k_indices, 1.0)
          sim = sim * mask

        adj_matrix = (sim > 0.4).float()

        if drop_prob > 0.0:
          drop_mask = (torch.rand_like(adj_matrix) > drop_prob).float()
          adj_matrix = adj_matrix * drop_mask

        return adj_matrix

    def forward(self, node_feats, print_attn_probs=False):
        batch_size, num_nodes, _ = node_feats.shape

        # Compute the adjacency matrix from node features.
        adj_matrix = self.compute_adj_matrix(node_feats, drop_prob=0.2)  # Expected shape: (batch_size, num_nodes, num_nodes)

        node_feats = self.projection(node_feats)
        node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)

        # Now, get the indices of nonzero elements in the adj matrix.
        # Note: For a batched adj_matrix, nonzero() returns indices with shape (num_edges, 3)
        edges = (adj_matrix > 0.0).nonzero(as_tuple=False)
        # edges[:, 0] is the batch index, edges[:, 1] is the row index, and edges[:, 2] is the column index.
        batch_indices = edges[:, 0]
        offset = batch_indices * num_nodes
        edges_indices_row = offset + edges[:, 1]
        edges_indices_col = offset + edges[:, 2]

        node_feats_flat = node_feats.view(batch_size * num_nodes, self.num_heads, -1)

        a_input = torch.cat([
            torch.index_select(node_feats_flat, dim=0, index=edges_indices_row),
            torch.index_select(node_feats_flat, dim=0, index=edges_indices_col)
        ], dim=-1)

        attn_logits = torch.einsum('bhc,hc->bh', a_input, self.a)
        attn_logits = self.leakyrelu(attn_logits)

        # Create an attention matrix with shape (batch_size, num_nodes, num_nodes, num_heads)
        attn_matrix = torch.full((*adj_matrix.shape, self.num_heads), -1e9, device=node_feats.device)

        # Assign computed logits to positions where adj_matrix is 1.
        # We expand adj_matrix to have a head dimension.
        attn_matrix[(adj_matrix > 0.0).unsqueeze(-1).expand(-1, -1, -1, self.num_heads)] = attn_logits.reshape(-1)

        attn_probs = F.softmax(attn_matrix, dim=2)

        if print_attn_probs:
            print("attention probs \n", attn_probs.permute(0, 3, 1, 2).detach().cpu())

        node_feats = torch.einsum('bijh,bjhc->bihc', attn_probs, node_feats)

        if self.concat_heads:
            node_feats = node_feats.reshape(batch_size, num_nodes, -1)
        else:
            node_feats = node_feats.mean(dim=2)

        return node_feats

Combining all the models

In [13]:
class CnnGat(nn.Module):
  def __init__(self, cnn_encoder, gat_layer, embed_size, device, cnn_embed_size=512):
    super().__init__()
    self.device = device
    self.embed_size = embed_size
    self.cnn_encoder = cnn_encoder()
    self.gat_layer = gat_layer(
      c_in=cnn_embed_size,
      c_out=self.embed_size // 49, # there are 7x7 nodes from the cnn, so we calc how many features each node can have to have embed_size size when flattened
    )
    self.cnn_encoder.to(device)
    self.gat_layer.to(device)

  def forward(self, data):
    cnn_embeds = self.cnn_encoder(data) # (len data, cnn_embed_size, H, W) h, w = 7, 7
    B, C, H, W = cnn_embeds.shape
    nodes = cnn_embeds.permute(0, 2, 3, 1).reshape(B, H*W, C)
    nodes = self.gat_layer(nodes) # (len data, h * w, embedsize / 49)
    gat_embeds = nodes.view(nodes.size(0), -1) # (len data, embed size)
    return gat_embeds

Training Stuff

In [14]:
from tqdm.auto import tqdm
import torch.optim.lr_scheduler as lr_scheduler
import random
def check_val(model, num_checks=10):
    running_f1 = 0.0
    min_f1 = float('inf')
    max_f1 = float('-inf')
    running_map = 0.0
    for i in range(num_checks):
        test_ep = val_dataset[0]
        _, f1, mAP = prototypical_loss(model, test_ep, device, use_focal=True, gamma=2.0, alpha=1.0)
        f1_val = f1.item() if torch.is_tensor(f1) else f1
        running_f1 += f1_val
        running_map += mAP
        min_f1 = min(min_f1, f1_val)
        max_f1 = max(max_f1, f1_val)

    avg_f1 = running_f1 / num_checks
    avg_map = running_map / num_checks
    print(f"Val: Avg F1: {avg_f1:.2f}, min: {min_f1:.2f}, max: {max_f1:.2f}, Avg mAP: {avg_map:.2f}")

Main

In [15]:
def train(model, dataset, optimizer, batch_size, epochs, device):
  model.train()
  model.to(device)

  random.seed(42)
  torch.manual_seed(42)

  scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

  for epoch in range(epochs):
    running_loss = 0.0
    running_f1 = 0.0
    for batch in tqdm(range(batch_size)):
      optimizer.zero_grad()
      episode = dataset[0]

      loss, f1, mAP = prototypical_loss(model, episode, device, use_focal=True, gamma=2.0, alpha=1.0)
      loss += l2_regularization(model, lambda_l2=1e-4)
      loss.backward()
      optimizer.step()

      running_loss += loss.item()
      running_f1 += f1.item()

    epoch_loss = running_loss / batch_size
    epoch_f1 = running_f1 / batch_size
    epoch_mAP = mAP
    print(f'Epoch {epoch+1} Complete - Loss: {epoch_loss:.4f} F1: {epoch_f1:.2f} mAP: {epoch_mAP:.2f}')
    check_val(model)
    scheduler.step()
    torch.cuda.empty_cache()

In [18]:
model = CnnGat(ConvEncoder, GATLayer, embed_size=245, device=device)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4, weight_decay=1e-4)

train(model, dataset, optimizer, batch_size=1, epochs=10, device=device)



  0%|          | 0/1 [00:00<?, ?it/s]

support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
Epoch 1 Complete - Loss: 255.6853 F1: 0.00 mAP: 0.30
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size

  0%|          | 0/1 [00:00<?, ?it/s]

support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
Epoch 2 Complete - Loss: 569.6907 F1: 0.00 mAP: 0.30
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size

  0%|          | 0/1 [00:00<?, ?it/s]

support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
Epoch 3 Complete - Loss: 1205.1200 F1: 0.00 mAP: 0.43
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data size torch.Size([4, 3, 224, 224]), query data size torch.Size([4, 3, 224, 224])
support data siz

In [None]:
# Testing
import matplotlib.pyplot as plt

In [None]:
test_episode = test_dataset[0]

# print shape for debugging
print(f"{test_episode['support_data'].shape}, {test_episode['support_label'].shape}, {test_episode['query_data'].shape}, {test_episode['query_label'].shape}")

In [None]:
# get prediction and performance
pred, f1 = predict(model, test_episode, device)
print(f"f1 of test episode {f1:.2f}")

In [None]:
# prepare data for visualization
imgs = test_episode['query_data'].cpu().numpy().transpose(0, 2, 3, 1)
col_names = ['Accelerating_and_revving_and_vroom', 'Accordion', 'Acoustic_guitar', 'Aircraft', 'Alarm', 'Animal', 'Applause', 'Bark', 'Bass_drum', 'Bass_guitar', 'Bathtub_(filling_or_washing)', 'Bell', 'Bicycle', 'Bicycle_bell', 'Bird', 'Bird_vocalization_and_bird_call_and_bird_song', 'Boat_and_Water_vehicle', 'Boiling', 'Boom', 'Bowed_string_instrument', 'Brass_instrument', 'Breathing', 'Burping_and_eructation', 'Bus', 'Buzz', 'Camera', 'Car', 'Car_passing_by', 'Cat', 'Chatter', 'Cheering', 'Chewing_and_mastication', 'Chicken_and_rooster', 'Child_speech_and_kid_speaking', 'Chime', 'Chink_and_clink', 'Chirp_and_tweet', 'Chuckle_and_chortle', 'Church_bell', 'Clapping', 'Clock', 'Coin_(dropping)', 'Computer_keyboard', 'Conversation', 'Cough', 'Cowbell', 'Crack', 'Crackle', 'Crash_cymbal', 'Cricket', 'Crow', 'Crowd', 'Crumpling_and_crinkling', 'Crushing', 'Crying_and_sobbing', 'Cupboard_open_or_close', 'Cutlery_and_silverware', 'Cymbal', 'Dishes_and_pots_and_pans', 'Dog', 'Domestic_animals_and_pets', 'Domestic_sounds_and_home_sounds', 'Door', 'Doorbell', 'Drawer_open_or_close', 'Drill', 'Drip', 'Drum', 'Drum_kit', 'Electric_guitar', 'Engine', 'Engine_starting', 'Explosion', 'Fart', 'Female_singing', 'Female_speech_and_woman_speaking', 'Fill_(with_liquid)', 'Finger_snapping', 'Fire', 'Fireworks', 'Fixed-wing_aircraft_and_airplane', 'Fowl', 'Frog', 'Frying_(food)', 'Gasp', 'Giggle', 'Glass', 'Glockenspiel', 'Gong', 'Growling', 'Guitar', 'Gull_and_seagull', 'Gunshot_and_gunfire', 'Gurgling', 'Hammer', 'Hands', 'Harmonica', 'Harp', 'Hi-hat', 'Hiss', 'Human_group_actions', 'Human_voice', 'Idling', 'Insect', 'Keyboard_(musical)', 'Keys_jangling', 'Knock', 'Laughter', 'Liquid', 'Livestock_and_farm_animals_and_working_animals', 'Male_singing', 'Male_speech_and_man_speaking', 'Mallet_percussion', 'Marimba_and_xylophone', 'Mechanical_fan', 'Mechanisms', 'Meow', 'Microwave_oven', 'Motor_vehicle_(road)', 'Motorcycle', 'Music', 'Musical_instrument', 'Ocean', 'Organ', 'Packing_tape_and_duct_tape', 'Percussion', 'Piano', 'Plucked_string_instrument', 'Pour', 'Power_tool', 'Printer', 'Purr', 'Race_car_and_auto_racing', 'Rail_transport', 'Rain', 'Raindrop', 'Ratchet_and_pawl', 'Rattle', 'Rattle_(instrument)', 'Respiratory_sounds', 'Ringtone', 'Run', 'Sawing', 'Scissors', 'Scratching_(performance_technique)', 'Screaming', 'Screech', 'Shatter', 'Shout', 'Sigh', 'Singing', 'Sink_(filling_or_washing)', 'Siren', 'Skateboard', 'Slam', 'Sliding_door', 'Snare_drum', 'Sneeze', 'Speech', 'Speech_synthesizer', 'Splash_and_splatter', 'Squeak', 'Stream', 'Strum', 'Subway_and_metro_and_underground', 'Tabla', 'Tambourine', 'Tap', 'Tearing', 'Telephone', 'Thump_and_thud', 'Thunder', 'Thunderstorm', 'Tick', 'Tick-tock', 'Toilet_flush', 'Tools', 'Traffic_noise_and_roadway_noise', 'Train', 'Trickle_and_dribble', 'Truck', 'Trumpet', 'Typewriter', 'Typing', 'Vehicle', 'Vehicle_horn_and_car_horn_and_honking', 'Walk_and_footsteps', 'Water', 'Water_tap_and_faucet', 'Waves_and_surf', 'Whispering', 'Whoosh_and_swoosh_and_swish', 'Wild_animals', 'Wind', 'Wind_chime', 'Wind_instrument_and_woodwind_instrument', 'Wood', 'Writing', 'Yell', 'Zipper_(clothing)']

In [None]:
# prepare grid
grid_size = int(np.ceil(np.sqrt(imgs.shape[0])))
fig, axes = plt.subplots(grid_size, grid_size, figsize=(12, 12))
axes = axes.flatten()

for i in range(imgs.shape[0]):
    ax = axes[i]
    img = imgs[i]
    ax.imshow(img)
    ax.axis("off")
    active_indices = torch.nonzero(pred[i], as_tuple=True)[0]
    selected_labels = [col_names[i] for i in active_indices]

    # get query labels
    query_labels = torch.nonzero(test_episode['query_label'][i], as_tuple=True)[0]
    query_label_names = [col_names[i] for i in query_labels]

    ax.set_title(f"Pred: {', '.join(selected_labels)}\nTrue: {', '.join(query_label_names)}", fontsize=7)

for j in range(i+1, len(axes)):
    axes[j].axis("off")

plt.tight_layout()
plt.show()

In [None]:
# get avg f1 in test set

checks = 200
running_f1 = 0.0
min_f1 = 2.0
max_f1 = -1.0
for i in tqdm(range(checks)):
  test_ep = test_dataset.getEpisode(6, 6, 6)
  _, f1 = predict(model, test_ep, device)
  min_f1 = min(min_f1, f1)
  max_f1 = max(max_f1, f1)
  running_f1 += f1

running_f1 = running_f1 / checks
print(f"Avg f1 in test set {running_f1:.2f}, min: {min_f1:.2f}, max: {max_f1:.2f}")