In [1]:
import torch
import torch.nn as nn
import os
import torchvision.models as models
import copy
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

torch.manual_seed(42)

<torch._C.Generator at 0x1cfbdf5dc10>

In [2]:
def print_model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    print("%.2f MB" %(os.path.getsize("tmp.pt")/1e6))
    os.remove('tmp.pt')

In [3]:
input_size = (28,28)
mean = [0.485, 0.456, 0.406] 
std = [0.229, 0.224, 0.225]

In [4]:
transform = transforms.Compose([
    transforms.Resize(input_size),  # Resize to a fixed size
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])


In [5]:
class CustomDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Iterate over each subfolder (assuming one for dogs and one for cats)
        for label, folder_name in enumerate(['dog', 'cat']):
            folder_path = os.path.join(self.root_dir, folder_name)
            for file_name in os.listdir(folder_path):
                file_path = os.path.join(folder_path, file_name)
                
                # Try opening the file with PIL to check if it's a valid image
                try:
                    with Image.open(file_path) as img:
                        
                        if img.mode != 'RGB':
                            print(f"Skipping {file_path} because it does not have 3 channels (RGB)")
                            continue

                        self.image_paths.append(file_path)
                        self.labels.append(label)
                except Exception as e:
                    print(f"Skipping {file_path} due to error: {e}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]
        
        with Image.open(image_path) as img:

            if img.mode != 'RGB':
                img = img.convert('RGB')

            if self.transform:
                img = self.transform(img)
            
        return img, label


In [6]:
dataset = CustomDataset(root_dir='D:\Embedded ML\QAT-MobileNetV2\PetImages', transform=transform)

# Calculate split sizes
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size

# Split dataset into train and test
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# Create DataLoader for train and test sets
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\dog\10158.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\dog\10401.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\dog\10747.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\dog\10797.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\dog\11285.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\dog\11410.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\dog\11675.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\dog\11702.jpg due to error: cannot identify image file 'D:\\Embedded ML\\QAT-MobileNetV2\\PetImages\\dog\\11702.jpg'
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\dog\11849.jpg because it does not have 3 



Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\dog\9078.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\dog\9188.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\dog\Thumbs.db due to error: cannot identify image file 'D:\\Embedded ML\\QAT-MobileNetV2\\PetImages\\dog\\Thumbs.db'
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\cat\10125.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\cat\10501.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\cat\10820.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\cat\11095.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\cat\11210.jpg because it does not have 3 channels (RGB)
Skipping D:\Embedded ML\QAT-MobileNetV2\PetImages\cat\11565.jpg because it does not have 3 ch

In [7]:
for images, labels in train_loader:
    print(images.shape)
    print(labels.shape)
    break

torch.Size([32, 3, 28, 28])
torch.Size([32])


In [8]:
def train_epoch(model, criterion, optimizer, data_loader, device,epoch):
    model.train()
    
    epoch_loss = 0.0
    num_batches = len(data_loader)
    
    for batch_idx, (image, target) in enumerate(data_loader):
        image, target = image.to(device), target.to(device)
        
        # Forward pass
        output = model(image)
        
        # Calculate loss
        loss = criterion(output, target)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Accumulate batch loss
        epoch_loss += loss.item()
        
        # Print batch loss (optional)
        # print(f"Batch [{batch_idx + 1}/{num_batches}], Loss: {loss.item():.4f}")
    
    # Calculate average epoch loss
    avg_epoch_loss = epoch_loss / num_batches
    print(f"Epoch = {epoch+1} || Training Loss: {avg_epoch_loss:.4f}")
    

In [9]:
def evaluate(model, criterion, data_loader, device,epoch):
    
    model.eval()
    
    epoch_loss = 0.0
    
    correct_predictions = 0
    total_predictions = 0
    
    num_batches = len(data_loader)
    
    with torch.no_grad():
       
        for image, target in data_loader:
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)
            # Accumulate batch loss
            epoch_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = torch.max(output, 1)  # Get the predicted class index
            correct_predictions += (predicted == target).sum().item()
            total_predictions += target.size(0)
            
    # Calculate average epoch loss
    avg_epoch_loss = epoch_loss / num_batches
    accuracy = correct_predictions / total_predictions
    
    print(f"Epoch = {epoch+1} || Test Loss: {avg_epoch_loss:.4f} || Test Accuracy: {accuracy:.4f}")
        

In [10]:
class MobileNet(torch.nn.Module):
    def __init__(self):
        super(MobileNet, self).__init__()
        self.model = models.mobilenet_v2(weights='MobileNet_V2_Weights.DEFAULT')  
        
        # for param in self.model.parameters():
        #     param.requires_grad = False
        
        self.model.classifier[1] = nn.Sequential(
            nn.Linear(in_features=self.model.classifier[1].in_features,out_features=512),
            nn.LeakyReLU(negative_slope=0.02,inplace=False),
            nn.BatchNorm1d(num_features=512),
            nn.Dropout(p=0.4,inplace=False),
            nn.Linear(in_features=512,out_features=2),
            nn.Softmax(dim=1))
        
        # print(self.model)

    def forward(self, x):
        x = self.model(x)
        return x

In [11]:
model = MobileNet()

In [12]:
print_model_size(model)

11.76 MB


In [13]:
import platform
chip = platform.processor()

# if chip == 'arm':
#     backend = 'qnnpack'
# elif chip in ['x86_64', 'i386']:
#     backend = 'fbgemm'
# else:
#     raise SystemError("Backend is not supported")

# print(f"Using {backend} backend engine for {chip} CPU")

backend = 'x86'

torch.backends.quantized.engine = backend

In [14]:

from torch.quantization.quantize_fx import prepare_fx, convert_fx,prepare_qat_fx

example_inputs = (torch.randn(1, 3, 28, 28),)
qconfig = {
    "": torch.quantization.get_default_qat_qconfig(backend),
    "module_name": [
    #    ("features.13", None),    
    #    ("features.14", None),
    #    ("features.15", None),
    #    ("features.16", None),
    #    ("features.17", None),
    ]
}
#model_prepared = prepare_fx(model.eval(), qconfig,example_inputs)
model_prepared = prepare_qat_fx(model.train(), qconfig, example_inputs)

Please use torch.ao.quantization.get_default_qconfig_mapping or torch.ao.quantization.get_default_qat_qconfig_mapping. Example:
    qconfig_mapping = get_default_qconfig_mapping("fbgemm")
    model = prepare_fx(model, qconfig_mapping, example_inputs)


In [15]:
criterion = nn.CrossEntropyLoss(reduction='mean')
optimizer = torch.optim.AdamW(model_prepared.parameters(), lr = 0.0001)

In [16]:
for nepoch in range(20):
    train_epoch(model_prepared, criterion, optimizer, train_loader, torch.device('cpu'),nepoch)
    model_quantized = copy.deepcopy(model_prepared)
    model_quantized = convert_fx(model_quantized.eval())
    evaluate(model_quantized,criterion, test_loader,torch.device('cpu'),nepoch)
    



Epoch = 1 || Training Loss: 0.6354
Epoch = 1 || Test Loss: 0.5938 || Test Accuracy: 0.6952
Epoch = 2 || Training Loss: 0.5872
Epoch = 2 || Test Loss: 0.5731 || Test Accuracy: 0.7175
Epoch = 3 || Training Loss: 0.5638
Epoch = 3 || Test Loss: 0.5601 || Test Accuracy: 0.7345


KeyboardInterrupt: 

In [None]:
model_quantized = copy.deepcopy(model_prepared)
model_quantized = convert_fx(model_quantized.eval())
evaluate(model_quantized,criterion, data_loader_test,torch.device('cpu'),0)


Epoch = 1 || Test Loss: 0.5069 || Test Accuracy: 0.7930


In [None]:
print_model_size(model_quantized)

3.30 MB


In [None]:
traced_net = torch.jit.trace(model_quantized, torch.randn(1,3,28,28))

torch.jit.save(traced_net,'QATDogCatMobileNetV2.pt')