In [6]:
import torch
import torch.nn as nn
import torchvision.models 
import torchvision.transforms as transforms

from datapipeline.asl_image_data_module import ASLImageDataModule
from lightning.pytorch.loggers import WandbLogger
from models.asl_model import ASLModel
from models.training import train, PROJECT_NAME, ENTITY_NAME

In [7]:
img_size = 224

data_transforms = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(5),
    transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
datamodule = ASLImageDataModule(path="/exchange/dspro2/silent-speech/ASL_Dataset", transforms=data_transforms)

In [8]:
class ASLEfficientNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torchvision.models.efficientnet_b0()
        self.model.requires_grad_(False)
        self.model.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.model.classifier[1].in_features, 28),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        return self.model(x)

In [9]:
torch_model = ASLEfficientNet()
model = ASLModel(model=torch_model, criterion=nn.CrossEntropyLoss(), optimizer=torch.optim.Adam(torch_model.parameters(), lr=1e-4))

logger = WandbLogger(name="efficientnet-1", project=PROJECT_NAME)

train(model, datamodule, logger=logger)

Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mv8-luky[0m ([33mdspro2-silent-speech[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


FileNotFoundError: [WinError 3] The system cannot find the path specified: '/exchange/dspro2/silent-speech/ASL_Dataset/Train'