In [10]:
import numpy as np
import aisuite as ai
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from classification.classification_models.vit import ClassificationModel, ModelConfig
from classification.classification_models.vit import ImageWoofDataset
from classification.classification_metrics.metrics import ClassificationMetrics
from pathlib import Path
import albumentations as A
from tqdm import tqdm
from albumentations.pytorch import ToTensorV2
import json

from utils import extract_improvements

In [13]:
# Client used to send chat-completion requests through aisuite.
client = ai.Client()

# Load the statistics that are interpolated into the user prompt.
# The context manager closes the file handle as soon as parsing finishes
# (a bare open() leaves it open until garbage collection), and the explicit
# encoding avoids depending on the platform's locale default.
with open("statistics.json", encoding="utf-8") as stats_file:
    statistics = json.load(stats_file)
# System message: fixes the assistant's role for the conversation.
# Nothing is interpolated here, so a plain (non-f) string literal is enough.
system_prompt = """
        You are an experienced computer vision practitioner. 
        You are given a set of statistics and a user prompt. 
        You need to analyze the statistics and recommend a set of changes to the user prompt to improve the model's performance.
"""

# User message: describes the training setup, embeds `statistics` through
# the f-string placeholder, and asks for answers wrapped in
# <improvement>...</improvement> tags.
user_prompt = f"""
I have a model that is trained on the ImageWoof dataset.
The model is a TinyNet model.
The model is trained for 10 epochs.
The model is trained with a learning rate of 0.001.
The model is trained with a weight decay of 0.01.
The model is trained with a batch size of 2.
The model is trained with a num_workers of 4.

The results are as follows:
{statistics}

I want to improve the model's performance.
What changes can I make besides the hyperparameters
such as data centric approaches or concatenating traditional computer vision methods such as edge detection methods
along the channels of the model as input to improve the model's performance?

I want you to give me a list of changes that I can make to the model to improve its performance where the improvements should 
be enclosed in <improvement> and </improvement> tags without any punctuations, explanations, adjectives, just the name of the changes.
"""

# Conversation sent to the model: the system message first, then the user message.
messages = [
    dict(role="system", content=system_prompt),
    dict(role="user", content=user_prompt),
]



# Send the `messages` conversation to the model and keep the full response
# object; later cells read the reply text from it.
response = client.chat.completions.create(
    messages=messages,
    model="ollama:gemini-3-flash-preview",
    temperature=0.5,
    max_tokens=1000,
)

# Show the text of the first returned choice.
print(response.choices[0].message.content)

<improvement>Mixup</improvement>
<improvement>CutMix</improvement>
<improvement>RandAugment</improvement>
<improvement>AutoAugment</improvement>
<improvement>Canny edge detection</improvement>
<improvement>Sobel filter concatenation</improvement>
<improvement>Gabor filter concatenation</improvement>
<improvement>Histogram equalization</improvement>
<improvement>Test time augmentation</improvement>
<improvement>ImageNet pretraining</improvement>
<improvement>Local binary patterns</improvement>
<improvement>Laplacian filter concatenation</improvement>
<improvement>Color jitter</improvement>
<improvement>Erasing</improvement>
<improvement>HOG feature concatenation</improvement>


In [11]:
# Text of the first choice in the chat-completion response.
reply_text = response.choices[0].message.content

# extract_improvements comes from utils; presumably it collects the text
# between <improvement> tags — TODO confirm against utils.
# NOTE(review): run this cell after the request cell; the saved output
# appears to come from an earlier response than the one printed there.
improvement_list = extract_improvements(reply_text)
print(improvement_list)

['Apply Mixup and CutMix data augmentation', 'Concatenate Canny edge detection maps as an additional input channel', 'Initialize the model with ImageNet pretrained weights', 'Apply RandAugment transformation pipeline', 'Incorporate Sobel filter gradients along the channel dimension', 'Increase input image resolution to capture finer spatial details', 'Apply Test Time Augmentation during inference', 'Concatenate Local Binary Patterns features to the input image', 'Implement Gabor filter banks as fixed preprocessing layers', 'Utilize Fourier Transform magnitude components as auxiliary inputs', 'Perform data cleaning to remove mislabeled or noisy images in the training set', 'Apply color space conversion such as Lab or YCrCb and concatenate with RGB channels', 'Incorporate Histogram of Oriented Gradients as additional feature maps', 'Use AutoAugment specifically tuned for ImageNet style datasets', 'Apply Gaussian blurring or bilateral filtering as a noise reduction preprocessing step']
