In [1]:
import dataclasses
import time

import torch
import torch.nn as nn
import wandb
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.io import decode_image
from torchvision.models import EfficientNet_V2_S_Weights, efficientnet_v2_s
from torchvision.transforms import v2
from tqdm.notebook import tqdm
#from screenshot_dataset import ScreenshotDataset

In [10]:
@dataclasses.dataclass
class RunConfig:
    batch_size: int
    learning_rate: float
    momentum: float
    num_epochs: int
    architecture: str
config = RunConfig(
    batch_size = 8,
    learning_rate = 0.001,
    momentum = 0.0,
    num_epochs = 100,
    architecture = "EfficientNetV2_s",
)

In [11]:
augmentations = v2.Compose([
    v2.ToImage(),
    v2.RandomResizedCrop(size=(224, 224), antialias=True),
    v2.RandomHorizontalFlip(p=0.5),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.229, 0.224, 0.225]),
])

ds = ImageFolder("./screenshot_data/", transform=augmentations)
train, validate = torch.utils.data.random_split(ds, [0.8, 0.2])
train_dataloader = DataLoader(dataset=train, batch_size=config.batch_size, num_workers=4, pin_memory=True, shuffle=True)
validation_dataloader = DataLoader(dataset=validate, batch_size=config.batch_size, num_workers=4, pin_memory=False, shuffle=False)
print(ds.classes)

['not_screenshot', 'screenshot']


In [4]:
#torch.backends.quantized.engine = ""
device = torch.device("cuda")

In [31]:
class ScreenshotModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        self.head = nn.Linear(1000, 2)
    def forward(self, x):
        return self.head(self.stem(x))

In [32]:
model = ScreenshotModel().to(device)
preprocessing = EfficientNet_V2_S_Weights.IMAGENET1K_V1.transforms()
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate, momentum=config.momentum)
loss_fn = torch.nn.CrossEntropyLoss()

dataclasses.asdict(config)
run = wandb.init(
    # Set the wandb project where this run will be logged.
    project="screenshot_classifier",
    # Track hyperparameters and run metadata.
    config=dataclasses.asdict(config),
)

In [33]:
def train_epoch(model, data, optimizer, loss_fn):
    model.train()
    sample_count = 0
    total_loss = 0.0
    for i, (image, target) in enumerate(tqdm(data)):
        start_time = time.time()
        with torch.amp.autocast('cuda'):
            image = preprocessing(image)
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = loss_fn(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            sample_count += image.shape[0]
            total_loss += loss.item()
        end_time = time.time()
    print(f"Train epoch loss: {total_loss}\tAvg loss: {total_loss/sample_count}")
    return total_loss

def run_eval(model, data, loss_fn):
    model.eval()
    sample_count = 0
    total_loss = 0.0
    for i, (image, target) in enumerate(tqdm(data)):
        with torch.amp.autocast('cuda'):
            image = preprocessing(image)
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = loss_fn(output, target)
            total_loss += loss.item()
            sample_count += image.shape[0]
    print(f"Eval: total loss: {total_loss}\tavg loss: {total_loss/sample_count}")
    return total_loss

In [34]:
best_loss = 1e32
for epoch_idx in range(config.num_epochs):
    print(f"Epoch {epoch_idx}")
    batch_loss = train_epoch(model, train_dataloader, optimizer, loss_fn)
    validation_loss = run_eval(model, validation_dataloader, loss_fn)
    if validation_loss < best_loss:
        best_loss = validation_loss
        torch.save(model, f"screenshot_checkpoint_{epoch_idx}_{validation_loss}.pt")
    wandb.log({"batch_loss": batch_loss, "validation_loss": validation_loss})
run.finish()

Epoch 0


  0%|          | 0/43 [00:00<?, ?it/s]

Train epoch loss: 23.344816595315933	Avg loss: 0.06845987271353646


  0%|          | 0/11 [00:00<?, ?it/s]

Eval: total loss: 3.7736541777849197	avg loss: 0.044395931503352
Epoch 1


  0%|          | 0/43 [00:00<?, ?it/s]

Train epoch loss: 14.818244755268097	Avg loss: 0.043455263211929905


  0%|          | 0/11 [00:00<?, ?it/s]

Eval: total loss: 3.268159106373787	avg loss: 0.038448930663221026
Epoch 2


  0%|          | 0/43 [00:00<?, ?it/s]

Train epoch loss: 12.604303315281868	Avg loss: 0.03696276632047469


  0%|          | 0/11 [00:00<?, ?it/s]

Eval: total loss: 2.441347122192383	avg loss: 0.02872173084932215
Epoch 3


  0%|          | 0/43 [00:00<?, ?it/s]

Train epoch loss: 11.971184968948364	Avg loss: 0.03510611427844095


  0%|          | 0/11 [00:00<?, ?it/s]

Eval: total loss: 1.8718202114105225	avg loss: 0.0220214142518885
Epoch 4


  0%|          | 0/43 [00:00<?, ?it/s]

Train epoch loss: 8.824104070663452	Avg loss: 0.025877138037136222


  0%|          | 0/11 [00:00<?, ?it/s]

Eval: total loss: 2.29651565477252	avg loss: 0.027017831232617882
Epoch 5


  0%|          | 0/43 [00:00<?, ?it/s]

Train epoch loss: 8.149596005678177	Avg loss: 0.023899108521050372


  0%|          | 0/11 [00:00<?, ?it/s]

Eval: total loss: 2.041160386055708	avg loss: 0.024013651600655387
Epoch 6


  0%|          | 0/43 [00:00<?, ?it/s]

Train epoch loss: 7.322765275835991	Avg loss: 0.021474384973126072


  0%|          | 0/11 [00:00<?, ?it/s]

Eval: total loss: 1.7628811374306679	avg loss: 0.02073977808741962
Epoch 7


  0%|          | 0/43 [00:00<?, ?it/s]

Train epoch loss: 7.95750343054533	Avg loss: 0.023335787186349943


  0%|          | 0/11 [00:00<?, ?it/s]

Eval: total loss: 1.7737454883754253	avg loss: 0.020867593980887358
Epoch 8


  0%|          | 0/43 [00:00<?, ?it/s]

Train epoch loss: 7.2074248641729355	Avg loss: 0.021136143296694825


  0%|          | 0/11 [00:00<?, ?it/s]

Eval: total loss: 1.7000700682401657	avg loss: 0.020000824332237245
Epoch 9


  0%|          | 0/43 [00:00<?, ?it/s]

Train epoch loss: 6.035798013210297	Avg loss: 0.0177002874287692


  0%|          | 0/11 [00:00<?, ?it/s]

Eval: total loss: 1.6634170189499855	avg loss: 0.01956961198764689
Epoch 10


  0%|          | 0/43 [00:00<?, ?it/s]

Train epoch loss: 6.071434259414673	Avg loss: 0.017804792549603148


  0%|          | 0/11 [00:00<?, ?it/s]

Eval: total loss: 1.3524646311998367	avg loss: 0.015911348602351022
Epoch 11


  0%|          | 0/43 [00:00<?, ?it/s]

Train epoch loss: 6.244300380349159	Avg loss: 0.01831173132067202


  0%|          | 0/11 [00:00<?, ?it/s]

Eval: total loss: 1.6642730236053467	avg loss: 0.019579682630651137
Epoch 12


  0%|          | 0/43 [00:00<?, ?it/s]

ERROR: Unexpected segmentation fault encountered in worker.
 

RuntimeError: DataLoader worker (pid(s) 845960) exited unexpectedly

In [35]:
torch.save(model, "screenshot_checkpoint.pt")

# Export and Validate:

In [48]:
# 
model.eval()
input_tensor = torch.rand((1, 3, 224, 224), dtype=torch.float32).to(device)
torch.onnx.export(
    model,                  # model to export
    (input_tensor,),        # inputs of the model,
    "screenshot_checkpoint.onnx",        # filename of the ONNX model
    #dynamic_shapes=[{0: "batch"}],
    #dynamic_axes={"input":{0: "batch"}},  # When dynamo is False
    input_names=["input"],  # Rename inputs for the ONNX model
    output_names=["output"],
    dynamo=True             # True or False to select the exporter to use
)

[torch.onnx] Obtain model graph for `ScreenshotModel([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ScreenshotModel([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
Applied 220 of general pattern rewrite rules.


ONNXProgram(
    model=
        <
            ir_version=10,
            opset_imports={'': 20},
            producer_name='pytorch',
            producer_version='2.9.1+cu128',
            domain=None,
            model_version=None,
        >
        graph(
            name=main_graph,
            inputs=(
                %"input"<FLOAT,[1,3,224,224]>
            ),
            outputs=(
                %"output"<FLOAT,[1,2]>
            ),
            initializers=(
                %"stem.features.0.0.weight"<FLOAT,[24,3,3,3]>{Tensor(...)},
                %"stem.features.1.0.block.0.0.weight"<FLOAT,[24,24,3,3]>{Tensor(...)},
                %"stem.features.1.1.block.0.0.weight"<FLOAT,[24,24,3,3]>{Tensor(...)},
                %"stem.features.2.0.block.0.0.weight"<FLOAT,[96,24,3,3]>{Tensor(...)},
                %"stem.features.2.0.block.1.0.weight"<FLOAT,[48,96,1,1]>{Tensor(...)},
                %"stem.features.2.1.block.0.0.weight"<FLOAT,[192,48,3,3]>{Tensor(...)},
              

In [49]:
# Convert the two-part onnx + onnx data pair into a self-contained ONNX.
# Our model should be below the 2GB boundary for protos.
import onnx
onnx_model = onnx.load("screenshot_checkpoint.onnx")
onnx.save_model(onnx_model, "screenshot.onnx", save_as_external_data=False, all_tensors_to_one_file=True)

In [54]:
import numpy
import onnxruntime as ort
session = ort.InferenceSession("./screenshot.onnx")
blank = numpy.zeros((1, 3, 224, 224), dtype=numpy.float32)
out = torch.softmax(torch.tensor(session.run(["output"], {"input": blank,})[0][0]), axis=0)
print(out.shape)
print(out)

torch.Size([2])
tensor([0.5267, 0.4733])
