In [None]:
import os
import json
from glob import glob

import torch
import numpy as np
from tqdm.auto import tqdm
from torch.utils.data import TensorDataset, DataLoader, Dataset

from utils.func import read_jsonl
from utils.metric import evaluate
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

In [None]:
# Notebook-wide settings; every later cell reads these names.
model_name = "LLaVA-7B"  # selects the ./output/<model_name>/ directory that is read and written
num_classes = 1000  # number of per-class shards loaded; validation labels >= this are dropped
batch_size = 128  # batch size shared by the train and validation DataLoaders

### Convert Train set

In [None]:
# Convert the raw per-sample JSONL dumps into one tensor shard per class
# (``class/train_<label>.pth``), so the training set can later be loaded
# class by class instead of re-parsing JSON.
train_dir = f"./output/{model_name}/ImageNet_Train"
class_dir = f"{train_dir}/class/"
files = sorted(glob(f"{train_dir}/*.jsonl"))
# makedirs(exist_ok=True) replaces the racy exists()+mkdir() pair.
os.makedirs(class_dir, exist_ok=True)


def _save_class_shard(logits, labels, out_dir, merge=False):
    """Write one class's buffered samples to ``{out_dir}train_{label}.pth``.

    Args:
        logits: non-empty list with one entry of logits per sample.
        labels: list of labels aligned with ``logits``; the caller only
            buffers consecutive samples that share the same label.
        out_dir: destination directory, ending with a path separator.
        merge: True when a shard for this label was already written during
            the current run; the new samples are then appended to that shard
            instead of overwriting it.

    Returns:
        The ``(features, labels)`` tensor tuple that was saved.
    """
    feats = torch.tensor(logits, dtype=torch.float32)
    # Build the labels directly as int64 instead of round-tripping via float32.
    targets = torch.tensor(labels, dtype=torch.long)
    shard_path = f"{out_dir}train_{int(targets[-1])}.pth"
    if merge:
        prev_feats, prev_targets = torch.load(shard_path)
        feats = torch.cat([prev_feats, feats], dim=0)
        targets = torch.cat([prev_targets, targets], dim=0)
    torch.save((feats, targets), shard_path)
    return feats, targets


saved_labels = set()  # labels whose shard has already been written in this run
x, y = [], []
for file in files:
    with open(file) as f:
        for line in tqdm(f):
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing newline at EOF
            data = json.loads(line)
            # A label change closes the current class buffer.
            if y and data['label'] != y[-1]:
                _save_class_shard(x, y, class_dir, merge=y[-1] in saved_labels)
                saved_labels.add(y[-1])
                x, y = [], []
            x.append(data['logits'])
            y.append(data['label'])

if not y:
    # Without this guard the failure would be an opaque IndexError on y[-1].
    raise ValueError(f"No training samples found in {train_dir}/*.jsonl")
# Flush the final class, which has no following label change to trigger it.
x, y = _save_class_shard(x, y, class_dir, merge=y[-1] in saved_labels)

In [None]:
class INTrainDataset(Dataset):
    """In-memory dataset built from the per-class ``train_<i>.pth`` shards.

    Each shard is loaded as an ``(x, y)`` pair; the pairs for classes
    ``0 .. num_classes - 1`` are concatenated along dim 0 into ``self.x`` and
    ``self.y``.
    """

    def __init__(self, model_name, num_classes=1000):
        """Load and concatenate the shards stored under ``./output/<model_name>/``."""
        shards = [
            torch.load(f"./output/{model_name}/ImageNet_Train/class/train_{class_idx}.pth")
            for class_idx in range(num_classes)
        ]
        self.x = torch.cat([feats for feats, _ in shards], dim=0)
        self.y = torch.cat([targets for _, targets in shards], dim=0)

    def __len__(self):
        """Return the number of samples, i.e. the length of the label tensor."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` entries stored at position ``idx``."""
        features = self.x[idx]
        target = self.y[idx]
        return features, target
# Load every per-class shard for the configured model into one in-memory dataset.
train_dataset = INTrainDataset(model_name=model_name, num_classes=num_classes)

### Val set

In [None]:
# Val Set
# Each record is indexed by 'logits' and 'label'; only labels below
# num_classes are kept so the validation classes match the loaded train shards.
data = read_jsonl(f"./output/{model_name}/ImageNet_val.jsonl")
# Filter once: features and labels then come from the same record list, and
# this still works if `data` can only be iterated a single time.
val_records = [ins for ins in tqdm(data) if ins['label'] < num_classes]
x_val = torch.tensor([ins['logits'] for ins in val_records], dtype=torch.float32)
# Integer labels, matching the dtype of the training labels.
y_val = torch.tensor([ins['label'] for ins in val_records], dtype=torch.long)

x_val.shape, y_val.shape

In [None]:
# Wrap the validation tensors so they can be batched like the training set.
val_dataset = TensorDataset(x_val, y_val)

# Training batches are shuffled; validation is iterated in dataset order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

### Linear Classifier

In [None]:
# Fit an LDA classifier (default settings) on the full in-memory training set.
clf = LinearDiscriminantAnalysis()
clf.fit(X=train_dataset.x, y=train_dataset.y)

In [None]:
# Predict on the validation features and report top-1 accuracy in percent.
# flatten(start_dim=1) collapses any extra per-sample axes but always keeps the
# batch axis; an argument-less squeeze() would also drop the batch axis when
# there is a single validation sample, handing predict() a 1-D array.
y_pred = clf.predict(x_val.flatten(start_dim=1))
acc = (y_pred == y_val.numpy()).mean()
print(f"{acc*100:.2f}")