In [1]:
import torch
import pandas as pd
import re
import spacy
from torchtext.data.utils import get_tokenizer
import torch.nn as nn
import torch.nn.functional as F
import random
import torch.optim as optim
import torchtext
import time

In [2]:
from models import LSTMNet
import site
import os
os.environ['SP_DIR'] = '/opt/conda/lib/python3.11/site-packages'

accuracies = []

# nlp library of Pytorch
from torchtext import data

import warnings as wrn
wrn.filterwarnings('ignore')


SEED = 2021

torch.manual_seed(SEED)
torch.backends.cuda.deterministic = True

In [3]:
def clean_text(text):
    cleaned_text = re.sub(r'[^A-Za-z0-9]+', ' ', str(text))
    return cleaned_text

# Load and preprocess the data files
def load_and_preprocess(file_path):
    df = pd.read_csv(file_path, header=None, delimiter='\t') # Assuming tab-separated values in .data files
    df[1] = df[1].apply(clean_text) # Assuming the text is in the second column
    cleaned_file_path = file_path.replace('.data', '_cleaned.data')
    df.to_csv(cleaned_file_path, index=False, header=False)
    return cleaned_file_path

In [4]:
cleaned_train_file = load_and_preprocess('./data/train.data')
cleaned_valid_file = load_and_preprocess('./data/valid.data')
cleaned_test_file = load_and_preprocess('./data/test.data')

In [5]:
spacy_en = spacy.load('en_core_web_sm')

def spacy_tokenizer(text):
    return [tok.text for tok in spacy_en.tokenizer(text)]

LABEL = data.LabelField()
TEXT = data.Field(tokenize=spacy_tokenizer, batch_first=True, include_lengths=True)
fields = [("label", LABEL), ("text", TEXT)]

training_data = data.TabularDataset(path=cleaned_train_file, format="csv", fields=fields, skip_header=True)
validation_data = data.TabularDataset(path=cleaned_valid_file, format="csv", fields=fields, skip_header=True)
test_data = data.TabularDataset(path=cleaned_test_file, format="csv", fields=fields, skip_header=True)

print(vars(training_data.examples[0]))

train_data,valid_data = training_data.split(split_ratio=0.75,
                                            random_state=random.seed(SEED))

TEXT.build_vocab(train_data,
                 min_freq=5)

LABEL.build_vocab(train_data)
# Count the number of instances per class
label_counts = {LABEL.vocab.itos[i]: LABEL.vocab.freqs[LABEL.vocab.itos[i]] for i in range(len(LABEL.vocab))}
print("Number of instances per class:", label_counts)


print("Size of text vocab:",len(TEXT.vocab))

print("Size of label vocab:",len(LABEL.vocab))

TEXT.vocab.freqs.most_common(10)

# Creating GPU variable
#device = torch.device("cuda")
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cuda')

Sent_SIZE=32
print("Batch size initialized")


{'label': '1', 'text': ['AFP', 'US', 'presidential', 'hopeful', 'John', 'Kerry', 'slammed', 'President', 'George', 'W', 'Bush', 's', 'North', 'Korea', 'policy', 'amid', 'reports', 'of', 'a', 'mysterious', 'blast', 'in', 'the', 'isolated', 'Asian', 'state']}
Number of instances per class: {'2': 22550, '1': 22526, '3': 22497, '4': 22426}
Size of text vocab: 24171
Size of label vocab: 4
Batch size initialized


In [6]:
train_iterator,validation_iterator = data.BucketIterator.splits(
    (train_data,valid_data),
    batch_size = Sent_SIZE,
    # Sort key is how to sort the samples
    sort_key = lambda x:len(x.text),
    sort_within_batch = True,
    device = device
)

test_iterator = data.BucketIterator(
    test_data,
    batch_size=Sent_SIZE,
    sort_key=lambda x: len(x.text),
    sort_within_batch=True,
    device=device
)

In [7]:
SIZE_OF_VOCAB = len(TEXT.vocab)
EMBEDDING_DIM = 100
NUM_HIDDEN_NODES = 100
NUM_OUTPUT_NODES = len(LABEL.vocab)
NUM_LAYERS = 1
BIDIRECTION = False
DROPOUT = 0.2
BIT_WIDTH = 32
LSTM_BITWIDTH = 8

In [8]:
print(SIZE_OF_VOCAB)
print(NUM_OUTPUT_NODES)

24171
4


In [9]:
model = LSTMNet(SIZE_OF_VOCAB, EMBEDDING_DIM, NUM_HIDDEN_NODES, NUM_OUTPUT_NODES, NUM_LAYERS, BIDIRECTION, DROPOUT, BIT_WIDTH, LSTM_BITWIDTH)
criterion = nn.CrossEntropyLoss()
#criterion = nn.NLLLoss()

print(torch.cuda.is_available())

model = model.to(device)
optimizer = optim.Adam(model.parameters(),lr=1e-3)
#criterion = nn.BCELoss()
#criterion = criterion.to(device)

model

True


LSTMNet(
  (embedding): Embedding(24171, 100)
  (lstm): LSTM(100, 100, batch_first=True, dropout=0.2)
  (fc): Linear(in_features=100, out_features=4, bias=False)
  (sigmoid): Sigmoid()
)

In [10]:
def multi_class_accuracy(preds, y):
    _, predicted = torch.max(preds, 1)
    correct = (predicted == y).float()
    acc = correct.sum() / len(correct)
    return acc

def train(model,iterator,optimizer,criterion):
    
    epoch_loss = 0.0
    epoch_acc = 0.0
    
    model.train()
    
    for batch in iterator:
        
        # cleaning the cache of optimizer
        optimizer.zero_grad()
        
        text,text_lengths = batch.text
        #print("Text Length:", text_lengths[0].item())
        global Sent_SIZE
        #Sent_SIZE=text_lengths[0].item()
        #print("Sent Length:", Sent_SIZE)
        #print("Iterator Batch Size:", batch.batch_size)
        batch.batch_size=Sent_SIZE
        #print("Iterator Batch Size:", batch.batch_size)
        iterator = data.BucketIterator(
            train_data,
            batch_size=Sent_SIZE,
            sort_key=lambda x: len(x.text),
            sort_within_batch=True,
            device=device
        )
        
        # forward propagation and squeezing
        predictions = model(text,text_lengths).squeeze()
        
        # computing loss / backward propagation
        loss = criterion(predictions, batch.label)
        #loss = criterion(predictions,batch.type)
        loss.backward()
        
        # accuracy
        acc = multi_class_accuracy(predictions,batch.label)
        
        # updating params
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc += acc.item()
        
    # It'll return the means of loss and accuracy
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def evaluate(model,iterator,criterion):
    
    epoch_loss = 0.0
    epoch_acc = 0.0
    
    # deactivate the dropouts
    model.eval()
    
    # Sets require_grad flat False
    with torch.no_grad():
        for batch in iterator:
            text,text_lengths = batch.text
            
            predictions = model(text,text_lengths).squeeze()
              
            #compute loss and accuracy
            loss = criterion(predictions, batch.label)
            acc = multi_class_accuracy(predictions, batch.label)
            
            #keep track of loss and accuracy
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def infer(model, iterator, criterion):
    epoch_loss = 0.0
    epoch_acc = 0.0
    
    model.eval()
    
    with torch.no_grad():
        for batch in iterator:
            text, text_lengths = batch.text
            predictions = model(text, text_lengths).squeeze()
            loss = criterion(predictions, batch.label)
            acc = multi_class_accuracy(predictions, batch.label)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

In [11]:
# Create dataset and iterator for test data
test_fields = [("label", LABEL), ("text", TEXT)]
test_data = data.TabularDataset(path=cleaned_test_file, format="csv", fields=test_fields, skip_header=True)

In [12]:
checkpoint = torch.load('./modelParameter_LSTM_AGNEWS_FP.pth')
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [13]:
for name, param in model.named_parameters():
    print(name)

embedding.weight
lstm.weight_ih_l0
lstm.weight_hh_l0
lstm.bias_ih_l0
lstm.bias_hh_l0
fc.weight


In [14]:
from brevitas.quant import Int8ActPerTensorFixedPoint, Int8WeightPerTensorFixedPoint
from brevitas.graph.quantize import preprocess_for_quantize
from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model
from brevitas_examples.imagenet_classification.ptq.ptq_common import calibrate

In [15]:
model = preprocess_for_quantize(
            model,
            equalize_iters=20,
            equalize_merge_bias=True,
            merge_bn=True,
            channel_splitting_ratio=0.0,
            channel_splitting_split_input=False)

1


In [16]:
dtype = getattr(torch, 'float')

In [51]:
quant_model = quantize_model(
        model,
        dtype=dtype,
        device=device,
        backend='layerwise',
        scale_factor_type='float_scale',
        bias_bit_width=32,
        weight_bit_width=4,
        weight_narrow_range=False,
        weight_param_method='stats',
        weight_quant_granularity='per_tensor',
        weight_quant_type='sym',
        layerwise_first_last_bit_width=4,
        act_bit_width=4,
        act_param_method='stats',
        act_quant_percentile=99.99,
        act_quant_type='sym',
        quant_format='int',
        layerwise_first_last_mantissa_bit_width=4,
        layerwise_first_last_exponent_bit_width=3,
        weight_mantissa_bit_width=4,
        weight_exponent_bit_width=3,
        act_mantissa_bit_width=4,
        act_exponent_bit_width=3)

In [52]:
for name, param in quant_model.named_parameters():
    print(name)

embedding.weight
lstm.weight_ih_l0
lstm.weight_hh_l0
lstm.bias_ih_l0
lstm.bias_hh_l0
fc.weight
fc.input_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl.value
fc.weight_quant.tensor_quant.scaling_impl.value


In [53]:
def calibrate(model, iterator):
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            text, text_lengths = batch.text
            _ = model(text, text_lengths)

In [54]:
calibrate(quant_model, train_iterator)

In [55]:
test_loss, test_acc = infer(quant_model, test_iterator, criterion)
print(f'\t Test. Loss: {test_loss:.3f} |  Test. Acc: {test_acc*100:.2f}%')

	 Test. Loss: 0.713 |  Test. Acc: 87.92%


In [60]:
from brevitas_examples.imagenet_classification.ptq.utils import get_model_config
model_config = get_model_config(quant_model)
model_config

{'inception_preprocessing': False,
 'resize_shape': 256,
 'center_crop_shape': 224}

In [62]:
center_crop_shape = model_config['center_crop_shape']
img_shape = center_crop_shape

In [64]:
device, dtype = next(model.parameters()).device, next(model.parameters()).dtype
ref_input = torch.ones(1, 3, img_shape, img_shape, device=device, dtype=dtype)