In [1]:
import numpy as np
import aisuite as ai
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from classification.classification_models.vit import ClassificationModel, ModelConfig
from classification.classification_models.vit import ImageWoofDataset
from classification.classification_metrics.metrics import ClassificationMetrics
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Send a single chat completion request through the aisuite client.
client = ai.Client()

system_message = {
    "role": "system",
    "content": "You are a helpful assistant that can help with computer vision research.",
}
user_message = {
    "role": "user",
    "content": "What is the latest research in computer vision related to multimodal object detection?",
}
messages = [system_message, user_message]

response = client.chat.completions.create(
    model="ollama:gemini-3-flash-preview",
    messages=messages,
    max_tokens=1000,
    temperature=0.5,
)

# Print the text of the first returned choice.
print(response.choices[0].message.content)

Research in multimodal object detection has shifted significantly over the last 24 months. We have moved away from "closed-set" detection (identifying 80 categories like in COCO) toward **Open-Vocabulary Detection (OVD)** and **Vision-Language (VL) Grounding**, where models can detect any object described in natural language.

Here is a breakdown of the latest research trends, key architectures, and paradigms in multimodal object detection as of late 2023 and 2024.

---

### 1. Open-Vocabulary Detection (OVD)
The goal of OVD is to detect objects that were not present in the labeled training set by leveraging knowledge from large-scale vision-language models like CLIP.

*   **YOLO-World (2024):** This is one of the most significant recent breakthroughs. It introduces a real-time open-vocabulary detector. By using a Vision-Language Path Aggregation Network (RepVL-PAN), it allows the model to detect objects based on custom prompts in real-time, making OVD viable for edge devices.
*   **Gr

# Error Analysis

In [3]:
class ClassificationErrorAnalysis(pl.LightningModule):
    """Run a model over a dataloader and derive classification metrics.

    Args:
        model: Network invoked as ``model(x)``; its output is passed through
            a softmax over dimension 1.
        dataloader: Iterable yielding ``(x, y)`` batches.
    """

    def __init__(self, model:nn.Module, dataloader: DataLoader):
        super().__init__()
        self.model = model
        self.dataloader = dataloader

    def get_statistics(self):
        """Evaluate ``self.model`` on ``self.dataloader`` and store the metrics.

        Populates ``self.aucroc``, ``self.f1``, ``self.accuracy`` and
        ``self.precision`` from ``ClassificationMetrics``. The AUC-ROC call
        receives the labels and the softmax probabilities; the other three
        receive the labels and the argmax class predictions.

        Raises:
            ValueError: If the dataloader yields no batches.
        """
        # TODO: implement parallel processing for this
        all_preds = []
        all_probs = []
        all_labels = []

        # Evaluate in eval mode (so mode-dependent layers behave as at
        # inference time) and restore the caller's mode afterwards.
        was_training = self.model.training
        self.model.eval()
        try:
            # No gradients are needed here; disabling autograd avoids keeping
            # a graph alive for every batch's accumulated probabilities.
            with torch.no_grad():
                for x, y in self.dataloader:
                    y_hat = self.model(x)
                    probs = torch.softmax(y_hat, dim=1)
                    pred_classes = torch.argmax(probs, dim=1)

                    all_probs.append(probs)
                    all_preds.append(pred_classes)
                    all_labels.append(y)
        finally:
            self.model.train(was_training)

        # torch.cat cannot concatenate an empty list; fail with a clear message.
        if not all_labels:
            raise ValueError("dataloader yielded no batches; cannot compute statistics")

        all_probs = torch.cat(all_probs, dim=0)
        all_preds = torch.cat(all_preds, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        self.aucroc = ClassificationMetrics.get_aucroc(all_labels, all_probs)
        self.f1 = ClassificationMetrics.get_f1(all_labels, all_preds)
        self.accuracy = ClassificationMetrics.get_accuracy(all_labels, all_preds)
        self.precision = ClassificationMetrics.get_precision(all_labels, all_preds)
    

# Research Agent

In [4]:
class ClassificationResearchAgent(ClassificationErrorAnalysis):
    """Error-analysis module that asks an LLM for improvement recommendations.

    Args:
        user_prompt: Text sent to the LLM as the user message.
        trained_model: Model forwarded to ``ClassificationErrorAnalysis``.
        dataloader: Dataloader forwarded to ``ClassificationErrorAnalysis``.
        researcher_model_name: Model identifier passed to the aisuite client
            as ``model=``.
    """

    def __init__(self, 
                 user_prompt: str,  
                 trained_model: nn.Module,
                 dataloader: DataLoader,
                 researcher_model_name: str = "ollama:gemini-3-flash-preview"
    ):
    
        super().__init__(trained_model, dataloader)
        self.user_prompt = user_prompt
        # Keep the LLM identifier; analyze_and_recommend passes it to the client.
        self.researcher_model_name = researcher_model_name
        self.ai = ai.Client()
        self.system_prompt = """
        You are an experienced computer vision practitioner. 
        You are given a set of statistics and a user prompt. 
        You need to analyze the statistics and recommend a set of changes to the user prompt to improve the model's performance.
        """

    def analyze_and_recommend(self):
        """Compute the evaluation statistics and request a recommendation.

        Calls ``get_statistics()`` and appends the resulting metrics to the
        user prompt, because the system prompt tells the LLM that statistics
        are provided.

        Returns:
            The message content of the first choice in the completion response.
        """
        self.get_statistics()

        statistics_report = (
            "Model evaluation statistics:\n"
            f"- AUC-ROC: {self.aucroc}\n"
            f"- F1: {self.f1}\n"
            f"- Accuracy: {self.accuracy}\n"
            f"- Precision: {self.precision}"
        )
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": f"{self.user_prompt}\n\n{statistics_report}"},
        ]

        results = self.ai.chat.completions.create(
            model=self.researcher_model_name,
            messages=messages,
            max_tokens=1000,
            temperature=0.2,
        )

        return results.choices[0].message.content


# Research Agent Test

In [None]:
# Restore the trained classifier from its Lightning checkpoint.
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_path = "classification/classification_models/lightning_logs/version_4/checkpoints/tinynet-epoch=10-val_acc=0.8400.ckpt"
trained_model = ClassificationModel.load_from_checkpoint(checkpoint_path)

# Switch to evaluation mode; as the last expression, the result is displayed by the cell.
trained_model.eval()

ClassificationModel(
  (model): EfficientNet(
    (conv_stem): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNormAct2d(
      32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): SiLU(inplace=True)
    )
    (blocks): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn1): BatchNormAct2d(
            32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): SiLU(inplace=True)
          )
          (aa): Identity()
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): SiLU(inplace=True)
            (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (gate): Sigmoid()
          )
          (conv_p

: 

In [None]:
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
val_path = Path('/Users/jeremyong/Desktop/research_agent/dataset/imagewoof-160/val')

# Validation preprocessing: resize, normalise per channel, convert to a tensor.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
val_transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=channel_mean, std=channel_std),
    ToTensorV2(),
])

# Build the validation dataset and an unshuffled loader with one sample per batch.
val_dataset = ImageWoofDataset(val_path, transform=val_transform)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

research_agent = ClassificationResearchAgent(
    user_prompt="What is the latest research in computer vision related to multimodal object detection?",
    trained_model=trained_model,
    dataloader=val_loader,
    researcher_model_name="ollama:gemini-3-flash-preview",
)

# Run the agent and print the text it returns.
recommendation = research_agent.analyze_and_recommend()
print(recommendation)