In [None]:
%pip install transformers
%pip install mlflow
%pip install datasets   #!pip

In [None]:
!pip install torch torchvision torchaudio

In [4]:
import os

import mlflow
import pandas as pd
import requests
import torch
from datasets import load_dataset
from PIL import Image
from sklearn.metrics import f1_score
from transformers import ViTFeatureExtractor, ViTForImageClassification

In [None]:
# Only the held-out "test" split of CIFAR-10 is needed for evaluation.
test_ds = load_dataset(path="cifar10", split="test")

In [7]:
# Single source of truth for the checkpoint: the same id is both loaded and logged.
MODEL_ID = "nateraw/vit-base-patch16-224-cifar10"
# Images per forward pass; reduce if the machine runs out of memory.
BATCH_SIZE = 32

mlflow.autolog()
# NOTE(review): the tracking URI is only set in a later cell, so unless MLFLOW_TRACKING_URI
# is already set in the environment, this run is recorded in MLflow's default local store
# (the recorded output below shows a file:// artifact URI).
with mlflow.start_run() as run:
  mlflow.log_params({
    # Short name without the "nateraw/" owner prefix -- the same value earlier runs logged.
    "model_name": MODEL_ID.split("/")[-1],
    "dataset_name": "CIFAR-10 Test",
  })

  feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_ID)
  model = ViTForImageClassification.from_pretrained(MODEL_ID)
  model.eval()  # evaluation only: keep dropout and similar layers in inference mode

  true_labels = []
  predicted_labels = []
  # Inference only: don't record autograd state for the forward passes (saves memory and time).
  with torch.no_grad():
    for start in range(0, len(test_ds), BATCH_SIZE):
      # Slicing a datasets.Dataset yields a dict mapping column name -> list of values.
      batch = test_ds[start:start + BATCH_SIZE]
      inputs = feature_extractor(images=batch['img'], return_tensors="pt")
      logits = model(**inputs).logits
      predicted_labels.extend(logits.argmax(-1).tolist())
      true_labels.extend(batch['label'])

  # Calculate metrics
  num_correct = sum(pred == true for pred, true in zip(predicted_labels, true_labels))
  accuracy = num_correct / len(true_labels)
  f1 = f1_score(y_true=true_labels, y_pred=predicted_labels, average='macro')

  # Log metrics
  mlflow.log_metric("test_accuracy", accuracy)
  mlflow.log_metric("f1_score", f1)

  # Log model
  mlflow.pytorch.log_model(model, "cifar10_model")

  print(run.info.artifact_uri)
  print(f"Run ID: {run.info.run_id}")

  # Now run `mlflow ui` in a terminal to browse this run.

  

2023/10/01 18:17:20 INFO mlflow.tracking.fluent: Autologging successfully enabled for transformers.
2023/10/01 18:17:20 INFO mlflow.tracking.fluent: Autologging successfully enabled for sklearn.


Downloading pytorch_model.bin:   0%|          | 0.00/343M [00:00<?, ?B/s]

file:///Users/travisrolle/Downloads/mlruns/0/0750b522eb8b4cb5a1081e77c9bbe39c/artifacts
Run ID: 0750b522eb8b4cb5a1081e77c9bbe39c


In [8]:
# Point MLflow at the tracking server. MLFLOW_TRACKING_URI (if set) takes precedence, so the
# notebook is not tied to one machine's LAN address; the original address is the fallback.
# NOTE: this cell executed AFTER the evaluation run above, so that run was not sent to the
# server (its artifact URI printed as a local file:// path). Move this cell before the run
# cell so future runs are logged to the server.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "http://192.168.5.172:5000"))

In [9]:
import sys

# Interpreter version for the record; the bare expression renders as the cell output.
sys.version


'3.9.6 (default, May  7 2023, 23:32:44) \n[Clang 14.0.3 (clang-1403.0.22.14.1)]'

In [10]:
import transformers
import torch

# Report the library versions this notebook was run with.
for label, module in (("Transformers version:", transformers), ("PyTorch version:", torch)):
  print(label, module.__version__)



Transformers version: 4.33.3
PyTorch version: 2.0.1
