In [28]:
import pandas as pd
# from app.vision_automl.ml_engine.trainer import FabricTrainer
from app.vision_automl.ml_engine.model import ClassificationModel
from app.vision_automl.ml_engine.datamodule import ClassificationData
from torch import nn, optim

In [29]:
  # -H "Content-Type: multipart/form-data" \
  # -F "csv_file=@./sample_data/Garbage_Dataset_Classification/metadata.csv" \
  # -F "images_zip=@./sample_data/Garbage_Dataset_Classification/images.zip" \
  # -F "filename_column=filename" \
  # -F "label_column=label" \
  # -F "task_type=classification" \
  # -F "time_budget=30")

session_record = {
  "csv_file_path" : "/Users/smukherjee/Documents/Projects/ALFIE/alfie_automl_engine/uploaded_data/e42706fd-cdfe-4a42-818e-3363dd0b91b3/labels.csv", 
  "images_dir_path": "/Users/smukherjee/Documents/Projects/ALFIE/alfie_automl_engine/uploaded_data/e42706fd-cdfe-4a42-818e-3363dd0b91b3/images/images/", 
  "filename_column": "filename", 
  "label_column": "label", 
  "time_budget" : 30,
}

In [30]:
from typing import cast


datamodule = ClassificationData(
        csv_file=str(session_record["csv_file_path"]),
        root_dir=str(session_record["images_dir_path"]),
        img_col=str(session_record["filename_column"]),
        label_col=str(session_record["label_column"]),
        batch_size=64,
    )

# Train with FabricTrainer under a time budget
# ensure proper type for time limit
time_limit: float = float(cast(int, session_record["time_budget"]))


In [31]:
import time
from typing import Any, Dict, List, Optional

import lightning as L
import torch
from torch import nn, optim
from tqdm import tqdm

class FabricTrainer:
    def __init__(
        self,
        datamodule,
        model_class,
        model_kwargs: Optional[Dict[str, Any]] = None,
        optimizer_class=optim.AdamW,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        loss_fn: nn.Module = nn.CrossEntropyLoss(),
        lr: float = 0.001,
        epochs: int = 1,
        time_limit: Optional[float] = None,
        device: str = "auto",
        callbacks: Optional[List[Any]] = None,
        input_dtype: torch.dtype = torch.float32,
        target_dtype: torch.dtype = torch.long,
    ):
        self.datamodule = datamodule
        self.model_class = model_class
        self.model_kwargs = model_kwargs or {}
        self.optimizer_class = optimizer_class
        self.optimizer_kwargs = optimizer_kwargs or {"lr": lr}
        self.loss_fn = loss_fn
        self.epochs = epochs
        self.time_limit = time_limit
        self.device = device
        self.callbacks = callbacks or []
        self.input_dtype = input_dtype
        self.target_dtype = target_dtype

        self.fabric = L.Fabric(devices=self.device)
        self._setup_model_optimizer()

    def _setup_model_optimizer(self):
        self.model = self.model_class(**self.model_kwargs)
        self.optimizer = self.optimizer_class(self.model.parameters(), **self.optimizer_kwargs)

        train_loader = self.datamodule.train_dataloader()
        val_loader = self.datamodule.val_dataloader()
        self.model, self.optimizer, self.train_loader, self.val_loader = self.fabric.setup(
            self.model, self.optimizer, train_loader, val_loader
        )
        self.test_loader = self.datamodule.test_dataloader()

    def _move_batch(self, imgs, labels):
        imgs = imgs.to(self.fabric.device, dtype=self.input_dtype)
        labels = labels.to(self.fabric.device, dtype=self.target_dtype)
        return imgs, labels

    def _check_time_limit(self, start_time: float) -> bool:
        if self.time_limit and (time.time() - start_time) > self.time_limit:
            self.fabric.print("Time limit reached. Stopping training.")
            return True
        return False

    def train_epoch(self, epoch: int, start_time: float) -> float:
        self.model.train()
        running_loss = 0.0
        for imgs, labels in tqdm(self.train_loader, desc=f"Epoch {epoch+1} Training", leave=False):
            if self._check_time_limit(start_time):
                return running_loss / max(1, len(self.train_loader))
            imgs, labels = self._move_batch(imgs, labels)
            self.optimizer.zero_grad()
            outputs = self.model(imgs)
            loss = self.loss_fn(outputs, labels)
            self.fabric.backward(loss)
            self.optimizer.step()
            running_loss += loss.item()
        return running_loss / len(self.train_loader)

    def validate(self, start_time: float):
        self.model.eval()
        val_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for imgs, labels in tqdm(self.val_loader, desc="Validation", leave=False):
                if self._check_time_limit(start_time):
                    break
                imgs, labels = self._move_batch(imgs, labels)
                outputs = self.model(imgs)
                loss = self.loss_fn(outputs, labels)
                val_loss += loss.item()
                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        avg_loss = val_loss / max(1, len(self.val_loader))
        acc = correct / max(1, total)
        self.fabric.print(f"Val Loss: {avg_loss:.4f}, Val Acc: {acc:.4f}")
        return avg_loss, acc

    def test(self):
        self.model.eval()
        test_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for imgs, labels in tqdm(self.test_loader, desc="Testing"):
                imgs, labels = self._move_batch(imgs, labels)
                outputs = self.model(imgs)
                loss = self.loss_fn(outputs, labels)
                test_loss += loss.item()
                preds = outputs.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        avg_loss = test_loss / len(self.test_loader)
        acc = correct / total
        self.fabric.print(f"\nTest Loss: {avg_loss:.4f}, Test Acc: {acc:.4f}")
        return avg_loss, acc

    def fit(self):
        start_time = time.time()
        for epoch in range(self.epochs):
            train_loss = self.train_epoch(epoch, start_time)
            val_loss, val_acc = self.validate(start_time)
            logs = {"train_loss": train_loss, "val_loss": val_loss, "val_acc": val_acc}
            for cb in self.callbacks:
                cb.on_epoch_end(self, epoch, logs)
            if self._check_time_limit(start_time):
                break
        return self.test()




In [None]:
import onnx
def export_to_onnx(model):
    return torch.onnx.export(model)

In [38]:
inputs, labels = next(iter(datamodule.train_dataloader()))  

In [47]:
dummy = inputs[1].unsqueeze(0)

In [None]:
model = trainer.model.to(torch.device('cpu')).eval()

# If inputs[1] is shape [3, 224, 224], make it [1, 3, 224, 224]
dummy = inputs[1].unsqueeze(0)

torch.onnx.export(
    model,
    dummy,
    "model.onnx",
)

AttributeError: 'NoneType' object has no attribute 'flush'

In [28]:
from huggingface_hub import HfApi

def get_num_params_if_available(repo_id: str, revision: str | None = None):
    api = HfApi()
    info = api.model_info(repo_id, revision=revision, files_metadata=True)
    num_params = getattr(info, "safetensors", None)
    if num_params is not None:
        return num_params.total
    else:
        return None

def search_hf_for_pytorch_models_with_estimated_parameters(filter = "image-classification",limit: int = 3, sort = "downloads"):
    api = HfApi()
    models = api.list_models(
        filter=filter,
        library="pytorch",
        sort=sort, 
        direction=-1,        
        limit=limit,
    )
    
    model_list_pass_one = [
        {
            "model_id": m.id,
            "downloads": getattr(m, "downloads", None),
            "likes": getattr(m, "likes", None),
            "last_modified": getattr(m, "lastModified", None),
            "private": getattr(m, "private", None),
            "num_params" : get_num_params_if_available(m.id)
        }
        for m in models
    ]

    # cleaned_up_list_iff_params
    return [model for model in model_list_pass_one if model["num_params"] is not None]

In [29]:
candidate_models = search_hf_for_pytorch_models_with_estimated_parameters(filter = "image-classification", limit = 20)

Invalid model-index. Not loading eval results into CardData.
Invalid model-index. Not loading eval results into CardData.


In [31]:
from operator import itemgetter
newlist = sorted(candidate_models, key=itemgetter('num_params'), reverse = True)
newlist

[{'model_id': 'timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k',
  'downloads': 367358,
  'likes': 4,
  'last_modified': None,
  'private': False,
  'num_params': 88337896},
 {'model_id': 'google/vit-base-patch16-224',
  'downloads': 3306984,
  'likes': 858,
  'last_modified': None,
  'private': False,
  'num_params': 86567656},
 {'model_id': 'timm/vit_base_patch16_224.augreg2_in21k_ft_in1k',
  'downloads': 373484,
  'likes': 9,
  'last_modified': None,
  'private': False,
  'num_params': 86567656},
 {'model_id': 'nateraw/vit-age-classifier',
  'downloads': 454519,
  'likes': 138,
  'last_modified': None,
  'private': False,
  'num_params': 85805577},
 {'model_id': 'trpakov/vit-face-expression',
  'downloads': 7903628,
  'likes': 78,
  'last_modified': None,
  'private': False,
  'num_params': 85804039},
 {'model_id': 'Falconsai/nsfw_image_detection',
  'downloads': 99558944,
  'likes': 811,
  'last_modified': None,
  'private': False,
  'num_params': 85800194},
 {'model_id': 'riz

In [25]:
candidate_models

[{'model_id': 'timm/mobilenetv3_small_100.lamb_in1k',
  'downloads': 112183087,
  'likes': 34,
  'last_modified': None,
  'private': False,
  'num_params': 2554968},
 {'model_id': 'Falconsai/nsfw_image_detection',
  'downloads': 99558944,
  'likes': 811,
  'last_modified': None,
  'private': False,
  'num_params': 85800194},
 {'model_id': 'timm/resnet50.a1_in1k',
  'downloads': 8285773,
  'likes': 39,
  'last_modified': None,
  'private': False,
  'num_params': None},
 {'model_id': 'trpakov/vit-face-expression',
  'downloads': 7903628,
  'likes': 78,
  'last_modified': None,
  'private': False,
  'num_params': 85804039},
 {'model_id': 'google/vit-base-patch16-224',
  'downloads': 3306984,
  'likes': 858,
  'last_modified': None,
  'private': False,
  'num_params': 86567656},
 {'model_id': 'timm/resnet18.a1_in1k',
  'downloads': 2729822,
  'likes': 12,
  'last_modified': None,
  'private': False,
  'num_params': 11699112},
 {'model_id': 'apple/mobilevit-small',
  'downloads': 1350424,
 

In [None]:
candidate_models

[{'model_id': 'timm/mobilenetv3_small_100.lamb_in1k',
  'downloads': 112183087,
  'likes': 34,
  'last_modified': None,
  'private': False,
  'num_params': 2554968},
 {'model_id': 'Falconsai/nsfw_image_detection',
  'downloads': 99558944,
  'likes': 811,
  'last_modified': None,
  'private': False,
  'num_params': 85800194},
 {'model_id': 'timm/resnet50.a1_in1k',
  'downloads': 8285773,
  'likes': 39,
  'last_modified': None,
  'private': False,
  'num_params': None},
 {'model_id': 'trpakov/vit-face-expression',
  'downloads': 7903628,
  'likes': 78,
  'last_modified': None,
  'private': False,
  'num_params': 85804039},
 {'model_id': 'google/vit-base-patch16-224',
  'downloads': 3306984,
  'likes': 858,
  'last_modified': None,
  'private': False,
  'num_params': 86567656},
 {'model_id': 'timm/resnet18.a1_in1k',
  'downloads': 2729822,
  'likes': 12,
  'last_modified': None,
  'private': False,
  'num_params': 11699112},
 {'model_id': 'apple/mobilevit-small',
  'downloads': 1350424,
 

In [18]:
estimate_model_size_mb("google/vit-base-patch16-224")

SafeTensorsInfo(parameters={'F32': 86567656}, total=86567656)

In [None]:
{'id': 'google/vit-base-patch16-224', 'author': 'google', 'sha': '3f49326eb077187dfe1c2a2bb15fbd74e6ab91e3', 'last_modified': datetime.datetime(2023, 9, 5, 15, 27, 12, tzinfo=datetime.timezone.utc), 'created_at': datetime.datetime(2022, 3, 2, 23, 29, 5, tzinfo=datetime.timezone.utc), 'private': False, 'gated': False, 'disabled': False, 'downloads': 3306984, 'downloads_all_time': None, 'likes': 858, 'library_name': 'transformers', 'gguf': None, 'inference': None, 'inference_provider_mapping': None, 'tags': ['transformers', 'pytorch', 'tf', 'jax', 'safetensors', 'vit', 'image-classification', 'vision', 'dataset:imagenet-1k', 'dataset:imagenet-21k', 'arxiv:2010.11929', 'arxiv:2006.03677', 'license:apache-2.0', 'autotrain_compatible', 'endpoints_compatible', 'region:us'], 'pipeline_tag': 'image-classification', 'mask_token': None, 'trending_score': None, 'card_data': {'base_model': None, 'datasets': ['imagenet-1k', 'imagenet-21k'], 'eval_results': None, 'language': None, 'library_name': None, 'license': 'apache-2.0', 'license_name': None, 'license_link': None, 'metrics': None, 'model_name': None, 'pipeline_tag': None, 'tags': ['vision', 'image-classification'], 'widget': [{'src': 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg', 'example_title': 'Tiger'}, {'src': 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/teapot.jpg', 'example_title': 'Teapot'}, {'src': 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/palace.jpg', 'example_title': 'Palace'}]}, 'widget_data': [{'src': 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg', 'example_title': 'Tiger'}, {'src': 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/teapot.jpg', 'example_title': 'Teapot'}, {'src': 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/palace.jpg', 'example_title': 'Palace'}], 'model_index': None, 'config': {'architectures': ['ViTForImageClassification'], 'model_type': 'vit'}, 'transformers_info': TransformersInfo(auto_model='AutoModelForImageClassification', custom_class=None, pipeline_tag='image-classification', processor='AutoImageProcessor'), 'siblings': [RepoSibling(rfilename='.gitattributes', size=744, blob_id='2514e63570ec0eb13d617d6a316ac52e9eed9746', lfs=None), RepoSibling(rfilename='README.md', size=5688, blob_id='550f72aac12a408fe77b7a2fdebc85057a77c0d1', lfs=None), RepoSibling(rfilename='config.json', size=69665, blob_id='21c61d3f8f3e50137a8c5fdd5bc9b085e286315a', lfs=None), RepoSibling(rfilename='flax_model.msgpack', size=346277993, blob_id='5f079399d5c4a8af31f21a05ad6913ea1a04f532', lfs=BlobLfsInfo(size=346277993, sha256='e0c809a1fe1d79c51c4da8a9ebd6c0923bac317ca469a33c6620977054149471', pointer_size=134)), RepoSibling(rfilename='model.safetensors', size=346293852, blob_id='d1e886a1fdb46f99c6cf9e603ed1c22fbf5c8453', lfs=BlobLfsInfo(size=346293852, sha256='1cea07110a4a47edc51420b2dda6f3b8b58e7256e8f44b4ea6aa9696162ccb5d', pointer_size=134)), RepoSibling(rfilename='preprocessor_config.json', size=160, blob_id='70fbc148eb26a06bac351d46fddc0a23037b4ce4', lfs=None), RepoSibling(rfilename='pytorch_model.bin', size=346351599, blob_id='a74020ed0ffd32c10585c3c21885b1a9e7f97d53', lfs=BlobLfsInfo(size=346351599, sha256='5f17067668129d23b52524f90a805e7d9914c276d90a59a13ebe81a09e40ceca', pointer_size=134)), RepoSibling(rfilename='tf_model.h5', size=346537664, blob_id='7665bbafc2d89b40ba2fa854a124762bfa528ccb', lfs=BlobLfsInfo(size=346537664, sha256='30f9125423060139b80ddae09daa8a1b612eb1eda8fc34a0b58cdfe920cbbc0f', pointer_size=134))], 'spaces': ['gunship999/SexyImages', 'Yntec/ToyWorld', 'llamameta/flux-pro-uncensored', 'Uthar/SexyReality', 'Yntec/PrintingPress', 'Nymbo/Compare-6', 'M2UGen/M2UGen-Demo', 'llamameta/fluxproV2', 'phenixrhyder/NSFW-ToyWorld', 'Yntec/ToyWorldXL', 'burtenshaw/autotrain-mcp', 'Yntec/blitz_diffusion', 'John6666/Diffusion80XX4sg', 'John6666/PrintingPress4', 'llamameta/fast-sd3.5-large', 'martynka/TasiaExperiment', 'yergyerg/ImgGenClone', 'Yntec/Image-Models-Test-2024', 'Nuno-Tome/simple_image_classifier', 'Yntec/Image-Models-Test-April-2024', 'DemiPoto/TestDifs', 'Abinivesh/Multi-models-prompt-to-image-generation', 'team-indain-image-caption/Hindi-image-captioning', 'tonyassi/product-recommendation', 'NativeAngels/Compare-6', 'abidlabs/vision-transformer', 'John6666/hfd_test_nostopbutton', 'Nymbo/Diffusion80XX4sg', 'DemiPoto/testSortModels', 'kaleidoskop-hug/PrintingPress', 'autonomous019/image_story_generator', 'Chakshu123/image-colorization-with-hint', 'Somnath3570/food_calories', 'Npps/Food_Indentification_and_Nutrition_Info', 'Yntec/MiniToyWorld', 'John6666/ToyWorld4', 'Ramos-Ramos/visual-emb-gam-probing', 'Chakshu123/sketch-colorization-with-hint', 'John6666/Diffusion80XX4g', 'SAITAN666/StableDiffusion35Large-Image-Models-Test-November-2024', 'NativeAngels/HuggingfaceDiffusion', 'Yntec/Image-Models-Test-December-2024', 'abidlabs/image-classifier', 'hysts/space-that-creates-model-demo-space', 'st0bb3n/Cam2Speech', 'jamesgray007/berkeley-ai-m3', 'John6666/Diffusion80XX4', 'K00B404/HuggingfaceDiffusion_custom', 'John6666/blitz_diffusion4', 'John6666/blitz_diffusion_builtin', 'eksemyashkina/clothes-segmentation', 'K00B404/SimpleBrothel', 'j0hngou/vision-diffmask', 'ipvikas/ImageProcessing', 'HighCWu/anime-colorization-with-hint', 'ClassCat/ViT-ImageNet-Classification', 'Yntec/Image-Models-Test-July-2024', 'Blane187/multi-diffusion', 'NativeAngels/ToyWorld', 'Uthar/LewdExperiments', 'Uthar/BodyPaint', 'Uthar/HRGiger', 'Uthar/HighFashion', 'Yntec/open-craiyon', 'Yntec/Image-Models-Test-March-2025', 'mmeendez/cnn_transformer_explainability', 'nickmuchi/Plant-Health-Classifier', 'Saiteja/leaf-ViT-classifier', 'dreamdrop-art/000555111', 'awacke1/MusicChatGenWithMuGen', 'Nuno-Tome/bulk_image_classifier', 'LucyintheSky/sketch-to-dress', 'andreped/vit-explainer', 'Somnath3570/food_calories_calculation', 'Shiladitya123Mondal/Food-Nutrition-app', 'swdqwewfw/Calorie_Calculator', 'Yeeezus/SexyImages', 'SharafeevRavil/test', 'John6666/MiniToyWorld', 'bryantmedical/oral_cancer', 'yiw/text', 'ThankGod/image-classifier', 'autonomous019/Story_Generator_v2', 'IPN/demo_', 'webis-huggingface-workshop/omar_demo', 'vebie91/spaces-image-classification-demo', 'suresh-subramanian/bean-classification', 'akhaliq/space-that-creates-model-demo-space', 'paschalc/ImageRecognitionDemo', 'peteralexandercharles/space-that-creates-model-demo-space', 'awacke1/MultiplayerImageRecognition-Gradio', 'mushroomsolutions/Gallery', 'xxx1/VQA_CAP_GPT', 'Kluuking/google-vit-base', 'Megareyka/imageRecognition', 'samavi/openai-clip-vit-base-patch32', 'HaawkeNeural/google-vit-base-patch16-224', 'jordonpeter01/M2UGen-Super-30s', 'JCRios/demo_image_classification', 'mrolando/classify_images'], 'safetensors': SafeTensorsInfo(parameters={'F32': 86567656}, total=86567656), 'security_repo_status': None, 'xet_enabled': None, 'lastModified': datetime.datetime(2023, 9, 5, 15, 27, 12, tzinfo=datetime.timezone.utc), 'cardData': {'base_model': None, 'datasets': ['imagenet-1k', 'imagenet-21k'], 'eval_results': None, 'language': None, 'library_name': None, 'license': 'apache-2.0', 'license_name': None, 'license_link': None, 'metrics': None, 'model_name': None, 'pipeline_tag': None, 'tags': ['vision', 'image-classification'], 'widget': [{'src': 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/tiger.jpg', 'example_title': 'Tiger'}, {'src': 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/teapot.jpg', 'example_title': 'Teapot'}, {'src': 'https://huggingface.co/datasets/mishig/sample_images/resolve/main/palace.jpg', 'example_title': 'Palace'}]}, 'transformersInfo': TransformersInfo(auto_model='AutoModelForImageClassification', custom_class=None, pipeline_tag='image-classification', processor='AutoImageProcessor'), '_id': '621ffdc136468d709f17b7d7', 'modelId': 'google/vit-base-patch16-224', 'usedStorage': 2550907501}

In [27]:
from transformers import AutoImageProcessor, AutoModelForImageClassification
test_model_id = candidate_models[0]['model_id']
processor = AutoImageProcessor.from_pretrained(test_model_id)
model = AutoModelForImageClassification.from_pretrained(test_model_id)

TimmWrapperForImageClassification(
  (timm_model): MobileNetV3(
    (conv_stem): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNormAct2d(
      16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
      (drop): Identity()
      (act): Hardswish()
    )
    (blocks): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
          (bn1): BatchNormAct2d(
            16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
            (drop): Identity()
            (act): ReLU(inplace=True)
          )
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): ReLU(inplace=True)
            (conv_expand): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
            (gate): Hardsigmoid()
          )
          (conv_pw): Conv2d(1

In [36]:
from torch import nn
from transformers import AutoModelForImageClassification

class HFImageClassifier(nn.Module):
    def __init__(
        self,
        hf_model_name: str = "google/vit-base-patch16-224",
        num_classes: int = 2,
        id2label: dict | None = None,
        label2id: dict | None = None,
    ):
        super().__init__()
        self.model = AutoModelForImageClassification.from_pretrained(
            hf_model_name,
            num_labels=num_classes,
            id2label=id2label,
            label2id=label2id,
            ignore_mismatched_sizes=True,
        )

    def forward(self, x):
        outputs = self.model(pixel_values=x)
        return outputs.logits

In [None]:
model = HFImageClassifier(
    hf_model_name="google/vit-base-patch16-224",
    num_classes=len(train_ds.classes),
    id2label=id2label,
    label2id=label2id,
)
trainer = FabricTrainer(
    datamodule=datamodule,
    model_class=model,
    model_kwargs={
        "pixel_values": None
    },
    optimizer_class=optim.AdamW,
    optimizer_kwargs={"lr": 0.001},
    loss_fn=nn.CrossEntropyLoss(),
    epochs=50,  # upper bound; time_limit will cut earlier
    time_limit=time_limit,
)
# test_loss, test_acc = trainer.fit()

AttributeError: 'NoneType' object has no attribute 'to'