In [1]:
import sys
sys.path.insert(0, "./docs/transformers/transformers/src")

from transformers import Blip2Processor
from lib.types import HFRepos, VQAParameters
from lib.daquar.daquar_generation import DaquarGeneration
from lib.easy_vqa.easyvqa_generation import EasyVQAGeneration
from lib.visualization import calculate_label_frequency
from lib.visualization import calculate_cardinality_and_density, create_label_frequency_boxplot
from lib.types import Suffix
%load_ext autoreload
%autoreload 2

# Path prefixes handed to the plotting / frequency helpers below (presumably the
# location their figures are written to). Each analysis cell appends a
# split-specific suffix such as "_all", "_train" or "_val".
EASY_VQA_COMBINED = "easyvqa_images/easyvqa_combined"
EASY_VQA_FILTERED = "easyvqa_images/easyvqa_filtered"
DAQUAR_COMBINED = "daquar_images/daquar_combined"
DAQUAR_FILTERED = "daquar_images/daquar_filtered"
DAQUAR_PROPORTIONAL = "daquar_images/daquar_proportional"

# BLIP-2 processor shared by every dataset built in this notebook
# (each cell assigns it to `args.processor`).
processor = Blip2Processor.from_pretrained(HFRepos.BLIP2_OPT)

  from .autonotebook import tqdm as notebook_tqdm


# Dataset Analysis
## Combined Easy-VQA dataset

Intuition around the boxplot:

1. **Outliers**: the two labels that occur more than 17,000 times each (`yes` and `no`) are likely to cause issues.
2. **Box and Whiskers**: the boxplot presents the interquartile range (IQR), and it is located near the bottom of the frequency axis. This indicates that the middle 50% of the label frequencies are clustered near the lower end of the distribution.

Potential issues:

1. **Imbalanced Training**: The model may become biased towards frequently occurring labels, leading to poor generalization and underperformance on less frequent labels.
2. **Overfitting to Common Labels**: The model might overfit to the highly frequent outlier labels, impairing its ability to accurately predict or generate responses for less common labels.
3. **Difficulty in Learning Rare Labels**: Underrepresented labels may not be learned effectively, causing the model to struggle with questions requiring these labels during inference.
4. **Reduced Model Robustness**: The model's robustness may be compromised, resulting in poor performance on unseen data, particularly if it contains more instances of less frequent labels.


### Box plots

In [67]:
# Rebuild the combined ("all") Easy-VQA data from scratch (recompute=True)
# and plot the spread of answer-label frequencies.
args = VQAParameters(Suffix.All, recompute=True)
args.processor = processor
dataset = EasyVQAGeneration(args)

create_label_frequency_boxplot(
    dataset,
    path=f"{EASY_VQA_COMBINED}_all",
)

INFO:lib.easy_vqa.easyvqa_base:Read combined dataset, length: 48248


Q1: 813.0
Q2: 826.0
Q3: 2105.0
IQR: 1292.0
Lower Whisker: -1125.0
Upper Whisker: 4043.0
Number of unique labels: 13
Mean frequency: 3711.38
Median frequency: 826.00
Number of outliers: 2
Total number of items: 48248


### Label Distribution Frequency

In [68]:
# Load the combined Easy-VQA data (presumably the copy cached by the previous
# cell, since recompute is not requested) and print its most and least
# frequent answers.
# Use the Suffix enum, like every other cell, instead of the bare "all" string.
args = VQAParameters(Suffix.All)
args.processor = processor
dataset = EasyVQAGeneration(args)
calculate_label_frequency(dataset, path=f"{EASY_VQA_COMBINED}_all")

INFO:lib.dataset_base:Loaded 48248 items from /home/atomwalk12/Dropbox (Old)/notes/vision/project/BeyondVisionQA/VisualQA/data/easy-vqa/generation/all.pkl


Total number of items:
 48248
Top 10 most frequent labels:
            Frequency
yes            17840
no             17703
rectangle       2151
circle          2105
triangle        2003
teal             841
green            826
gray             822
black            817
blue             813
Top 10 least frequent labels:
           Frequency
circle         2105
triangle       2003
teal            841
green           826
gray            822
black           817
blue            813
yellow          808
red             761
brown           758


## Preprocessing Results
The idea is that we would like to evenly distribute the data across the classes.
To do this, we can use the following approach:

1. Set a minimum of 100 samples per class for the training set and 25 for the validation set.
2. For classes with fewer than 125 samples in total, put all of them in the training set.
3. For classes with 125-3125 samples, use an 80-20 split.
4. For classes with over 3125 samples (outliers), cap the training at 1700 and validation at 400.


This approach would:
1. Ensure all classes are represented in both sets.
2. Prevent rare classes from being excluded.
3. Limit the influence of extremely common classes.
4. Maintain a reasonable overall train-validation split.

In [69]:
# Rebuild the proportional train/val/test splits from scratch and report the
# per-label frequencies of each, so the capping rules described above can be
# checked split by split. One loop replaces three copy-pasted blocks; the
# splits are processed in the same order as before (train, val, test).
proportional_splits = {}
for split_name in ("train", "val", "test"):
    args = VQAParameters(split_name, recompute=True, use_proportional_split=True)
    args.processor = processor
    proportional_splits[split_name] = EasyVQAGeneration(args)
    calculate_label_frequency(
        proportional_splits[split_name],
        path=f"{EASY_VQA_FILTERED}_{split_name}",
    )

# Later cells refer to the individual splits by these names.
dataset_train = proportional_splits["train"]
dataset_val = proportional_splits["val"]
dataset_test = proportional_splits["test"]

INFO:lib.easy_vqa.easyvqa_base:Read combined dataset, length: 48248
Map: 100%|██████████| 13558/13558 [00:01<00:00, 13223.11 examples/s]
Casting to class labels: 100%|██████████| 13558/13558 [00:00<00:00, 317095.79 examples/s]
INFO:lib.easy_vqa.easyvqa_base:Read train dataset, length: 10846
INFO:lib.dataset_base:Preparing data for training
Map: 100%|██████████| 10846/10846 [00:00<00:00, 12497.93 examples/s]


Total number of items:
 10846
Top 10 most frequent labels:
            Frequency
rectangle       1370
no              1360
yes             1351
circle          1347
triangle        1287
teal             532
green            530
yellow           526
gray             524
black            521
Top 10 least frequent labels:
           Frequency
circle         1347
triangle       1287
teal            532
green           530
yellow          526
gray            524
black           521
blue            520
red             492
brown           486


INFO:lib.easy_vqa.easyvqa_base:Read combined dataset, length: 48248
INFO:lib.easy_vqa.easyvqa_base:Read val dataset, length: 3347
INFO:lib.dataset_base:Preparing data for training
Map: 100%|██████████| 3347/3347 [00:00<00:00, 12735.26 examples/s]


Total number of items:
 3347
Top 10 most frequent labels:
            Frequency
rectangle        431
circle           421
triangle         401
no               400
yes              400
teal             169
green            166
gray             165
black            164
blue             163
Top 10 least frequent labels:
         Frequency
no            400
yes           400
teal          169
green         166
gray          165
black         164
blue          163
yellow        162
red           153
brown         152


INFO:lib.easy_vqa.easyvqa_base:Read combined dataset, length: 48248
Map: 100%|██████████| 13558/13558 [00:01<00:00, 12871.19 examples/s]
Casting to class labels: 100%|██████████| 13558/13558 [00:00<00:00, 305920.69 examples/s]
INFO:lib.easy_vqa.easyvqa_base:Read test dataset, length: 2712
INFO:lib.dataset_base:Preparing data for training
Map: 100%|██████████| 2712/2712 [00:00<00:00, 12256.04 examples/s]


Total number of items:
 2712
Top 10 most frequent labels:
            Frequency
no               340
yes              338
circle           337
rectangle        326
triangle         323
red              140
gray             137
blue             135
teal             132
brown            132
Top 10 least frequent labels:
            Frequency
rectangle        326
triangle         323
red              140
gray             137
blue             135
teal             132
brown            132
yellow           131
green            126
black            115


In [70]:
# Compare the label-frequency spread of the proportional train and validation
# splits built above.
train_plot_path = f"{EASY_VQA_FILTERED}_train"
val_plot_path = f"{EASY_VQA_FILTERED}_val"
create_label_frequency_boxplot(dataset_train, path=train_plot_path)
create_label_frequency_boxplot(dataset_val, path=val_plot_path)

Q1: 521.0
Q2: 530.0
Q3: 1347.0
IQR: 826.0
Lower Whisker: -718.0
Upper Whisker: 2586.0
Number of unique labels: 13
Mean frequency: 834.31
Median frequency: 530.00
Number of outliers: 0
Total number of items: 10846
Q1: 163.0
Q2: 166.0
Q3: 400.0
IQR: 237.0
Lower Whisker: -192.5
Upper Whisker: 755.5
Number of unique labels: 13
Mean frequency: 257.46
Median frequency: 166.00
Number of outliers: 0
Total number of items: 3347


# Combined DAQUAR dataset
## Combined dataset analysis
### Box Plots

Intuition around the boxplot:
1. The dataset is highly imbalanced: many labels have only a few samples (the median label frequency is 4, and many labels occur only once).
2. The boxplot shows that most label frequencies are clustered near the lower end of the distribution, while a few labels occur very often and appear as outliers above the upper whisker.

In [71]:
# Rebuild the combined DAQUAR data from scratch and plot how often each answer
# label occurs. Samples can carry more than one label, hence multilabel=True.
args = VQAParameters(Suffix.All, recompute=True)
args.processor = processor
dataset = DaquarGeneration(args)

create_label_frequency_boxplot(
    dataset,
    path=f"{DAQUAR_COMBINED}_all",
    multilabel=True,
)

Q1: 2.0
Q2: 4.0
Q3: 13.75
IQR: 11.75
Lower Whisker: -15.625
Upper Whisker: 31.375
Number of unique labels: 582
Mean frequency: 24.62
Median frequency: 4.00
Number of outliers: 92
Total number of items: 12468


### Label Frequency

In [72]:
# Most and least frequent labels of the unfiltered, combined DAQUAR data.
calculate_label_frequency(
    dataset,
    path=f"{DAQUAR_COMBINED}_all",
    multilabel=True,
)

Total number of items:
 12468
Top 10 most frequent labels:
          Frequency
2              554
table          469
chair          412
lamp           351
white          349
photo          341
3              327
picture        308
window         284
books          281
Top 10 least frequent labels:
                  Frequency
dish_rack                1
dog_cage                 1
file_stand               1
binder                   1
chest                    1
soap_holder              1
iron_grill               1
cat_cage                 1
staple_remover           1
indoor_fountain          1


## Filtered Data
### Box Plots

The problem is handled in the following way:
1. **Combine the training and validation sets.** This gives an overall picture of how much data is available per label.
2. **Remove infrequent labels.** Remove all labels with fewer than 50 examples in the combined set, since they can introduce a lot of noise during training.
3. **Do not remove upper-end outliers.** Since the task is relatively difficult, we'd like to keep as much data as possible, even the outliers at the upper end of the distribution.
4. **Stratify the split.** We use an 80-20 stratified split to ensure that the training and validation sets are representative of the original dataset.
5. **Class-specific weighting.** We use class weights to address the issue of class imbalance during training.

In [73]:
# Combined DAQUAR data restricted to the filtered label set, plotted the same
# way as the unfiltered version for comparison.
args = VQAParameters(Suffix.All, use_filtered_split=True)
args.processor = processor
filtered_dataset = DaquarGeneration(args)

create_label_frequency_boxplot(
    filtered_dataset,
    path=f"{DAQUAR_FILTERED}_filtered",
    multilabel=True,
)

INFO:lib.dataset_base:Loaded 9523 items from /home/atomwalk12/Dropbox (Old)/notes/vision/project/BeyondVisionQA/VisualQA/data/daquar/generation/all.pkl


Q1: 73.75
Q2: 104.5
Q3: 205.0
IQR: 131.25
Lower Whisker: -123.125
Upper Whisker: 401.875
Number of unique labels: 68
Mean frequency: 154.62
Median frequency: 104.50
Number of outliers: 3
Total number of items: 9523


### Label Frequency

In [74]:
# Most and least frequent labels that survive the filtering step.
calculate_label_frequency(
    filtered_dataset,
    path=f"{DAQUAR_FILTERED}_filtered",
    multilabel=True,
)

Total number of items:
 9523
Top 10 most frequent labels:
          Frequency
2              554
table          469
chair          412
lamp           351
white          349
photo          341
3              327
picture        308
window         284
books          281
Top 10 least frequent labels:
                             Frequency
light                              64
bowl                               60
basket                             58
stove                              57
night_stand                        56
gray                               56
toilet                             53
bottle_of_hand_wash_liquid         52
ornamental_plant                   50
plant                              50


## Preprocessing results
### Outliers
**Capping.** It is likely not a good idea to cap the number of examples per class, as this can lead to the loss of important data and information. Instead, we can use stratified sampling to ensure that each class is represented in the training and validation sets.


**Stratified Sampling.** Instead of capping the number of examples per class, we use stratified sampling to ensure that each class is represented in both the training and validation sets.

**Class Weighting.** Moreover, class weights address the remaining class imbalance during the training process itself, so that the model pays more attention to the rare classes.

### Box Plots

In [2]:
# Load the proportional DAQUAR train and validation splits and plot the
# per-label frequency spread of each. (These are individual splits, not the
# combined dataset, as the previous comments claimed.)
args = VQAParameters(Suffix.Train, use_proportional_split=True)  # training split
args.processor = processor
train_dataset = DaquarGeneration(args)
# Pass `path` by keyword, matching every other call site in this notebook.
create_label_frequency_boxplot(train_dataset, path=f"{DAQUAR_PROPORTIONAL}_train", multilabel=True)

args = VQAParameters(Suffix.Val, use_proportional_split=True)  # validation split
args.processor = processor
val_dataset = DaquarGeneration(args)
create_label_frequency_boxplot(val_dataset, path=f"{DAQUAR_PROPORTIONAL}_val", multilabel=True)

INFO:lib.dataset_base:Loaded 7604 items from /home/atomwalk12/Dropbox (Old)/notes/vision/project/BeyondVisionQA/VisualQA/data/daquar/generation/train.pkl
INFO:lib.dataset_base:Loaded 1919 items from /home/atomwalk12/Dropbox (Old)/notes/vision/project/BeyondVisionQA/VisualQA/data/daquar/generation/val.pkl


Q1: 58.75
Q2: 83.5
Q3: 164.0
IQR: 105.25
Lower Whisker: -99.125
Upper Whisker: 321.875
Number of unique labels: 68
Mean frequency: 123.72
Median frequency: 83.50
Number of outliers: 3
Total number of items: 7604
Q1: 15.0
Q2: 21.0
Q3: 41.0
IQR: 26.0
Lower Whisker: -24.0
Upper Whisker: 80.0
Number of unique labels: 68
Mean frequency: 30.90
Median frequency: 21.00
Number of outliers: 3
Total number of items: 1919


### Cardinality and Density

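- Label cardinality is the average number of labels per sample; a high value means samples tend to carry multiple labels, while a value close to 1 means most samples have a single label.
- Label density is the cardinality divided by the number of unique labels; a low value means each sample uses only a small fraction of the label set, which can also suggest that not all labels are used frequently.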

In [3]:
# Label cardinality and density of the proportional DAQUAR splits,
# training split first, then validation.
calculate_cardinality_and_density(train_dataset)
calculate_cardinality_and_density(val_dataset)

The number of unique labels: 68
Label Cardinality: 1.106391372961599
Label Density: 0.01627046136708234
Average number of labels per sample: 1.11
The number of unique labels: 68
Label Cardinality: 1.0948410630536738
Label Density: 0.01610060386843638
Average number of labels per sample: 1.09


### Label Frequency

In [77]:
# Per-label frequencies of both proportional splits; the return values are kept
# in named variables for later inspection.
label_frequency_train = calculate_label_frequency(
    train_dataset,
    path=f"{DAQUAR_PROPORTIONAL}_train",
    multilabel=True,
)
label_frequency_val = calculate_label_frequency(
    val_dataset,
    path=f"{DAQUAR_PROPORTIONAL}_val",
    multilabel=True,
)

Total number of items:
 7604
Top 10 most frequent labels:
          Frequency
2              443
table          375
chair          330
lamp           281
white          279
photo          273
3              262
picture        247
window         227
books          225
Top 10 least frequent labels:
                             Frequency
light                              51
bowl                               48
stove                              46
basket                             46
gray                               45
night_stand                        45
toilet                             42
bottle_of_hand_wash_liquid         42
plant                              40
ornamental_plant                   40
Total number of items:
 1919
Top 10 most frequent labels:
          Frequency
2              111
table           94
chair           82
white           70
lamp            70
photo           68
3               65
picture         61
window          57
books           56
Top 10 least fr

In [13]:
from lib.models.feature_visualizer import FeatureVisualizer
import pickle

# Build the Easy-VQA dataset; in this cell it is only needed for `id_to_answer`.
args = VQAParameters(Suffix.Train, recompute=True)
args.processor = processor
dataset = EasyVQAGeneration(args)

split = "train"
# Classifier outputs saved under classifier directory 710142242.
# To visualize the raw (pre-classifier) features instead, use:
#   f"data/models/easy_vqa/classifier/2088502146/features_{split}.pkl"
path = f"data/models/easy_vqa/classifier/710142242/features_{split}.pkl"

# pickle can execute arbitrary code while loading: only open feature files
# produced by our own training runs. The `with` block closes the file handle
# (the previous `pickle.load(open(...))` leaked it).
with open(path, "rb") as feature_file:
    data = pickle.load(feature_file)
features = data["features"]
labels = data["labels"]

feature_visualizer = FeatureVisualizer(id_to_answer=dataset.id_to_answer, dataset_name="easy_vqa")
feature_visualizer.set_features(features, labels, split)
feature_visualizer.visualize_features_with_umap()

INFO:lib.easy_vqa.easyvqa_base:Read combined dataset, length: 48248
INFO:lib.dataset_base:Preparing data for training
Map: 100%|██████████| 48248/48248 [00:02<00:00, 17051.95 examples/s]

You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any i

In [11]:

from lib.models.feature_visualizer import FeatureVisualizer
import pickle

# Build the DAQUAR dataset; in this cell it is only needed for `id_to_answer`.
args = VQAParameters(Suffix.Train, recompute=True)
args.processor = processor
dataset = DaquarGeneration(args)

split = "train"
# Classifier outputs saved under classifier directory 452947361.
# NOTE(review): the commented-out "raw features" path previously here pointed at
# an easy_vqa run (2088502146), not a DAQUAR one. Add the DAQUAR raw-feature
# path if that comparison is needed.
path = f"data/models/daquar/classifier/452947361/features_{split}.pkl"

# pickle can execute arbitrary code while loading: only open feature files
# produced by our own training runs. The `with` block closes the file handle
# (the previous `pickle.load(open(...))` leaked it).
with open(path, "rb") as feature_file:
    data = pickle.load(feature_file)
features = data["features"]
labels = data["labels"]

feature_visualizer = FeatureVisualizer(id_to_answer=dataset.id_to_answer, dataset_name="daquar")
feature_visualizer.set_features(features, labels, split)
feature_visualizer.visualize_features_with_umap()

INFO:lib.dataset_base:Preparing data for training
Map: 100%|██████████| 12468/12468 [00:00<00:00, 18921.54 examples/s]

You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


n_jobs value 1 overrid