# IT Service Desk Ticket Classifier - Enhanced Training

**7,200 Samples (5,760 Train / 720 Val / 720 Test) | 12 Categories | DistilBERT + Focal Loss**

### Quick Start
1. Runtime > Change runtime type > GPU (T4)
2. Run all cells
3. Download trained model at the end

In [None]:
!pip install -q transformers torch scikit-learn pandas numpy tqdm huggingface_hub matplotlib seaborn

In [None]:
import torch
print(f'PyTorch: {torch.__version__}')
print(f'CUDA: {torch.cuda.is_available()}')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if torch.cuda.is_available(): print(f'GPU: {torch.cuda.get_device_name(0)}')

## Generate Enhanced Training Data (7,200 samples)

In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

np.random.seed(42)

# Enhanced templates - 10 per category
TEMPLATES = {
    'Hardware': [
        ('Laptop screen flickering', 'My laptop screen has been flickering intermittently for the past few days.'),
        ('Keyboard not responding', 'Several keys on my keyboard have stopped working.'),
        ('Monitor display issues', 'My external monitor shows distorted colors and lines.'),
        ('Laptop overheating', 'My laptop is getting extremely hot and shuts down.'),
        ('Mouse cursor jumping', 'The mouse cursor jumps around randomly.'),
        ('Docking station problems', 'Docking station is not detecting external monitors.'),
        ('Battery not charging', 'Laptop battery is not charging even when plugged in.'),
        ('USB ports not working', 'None of the USB ports recognize any devices.'),
        ('Webcam not functioning', 'Built-in webcam shows black screen during calls.'),
        ('Laptop won\'t turn on', 'Laptop won\'t power on at all, no lights.'),
    ],
    'Software': [
        ('Microsoft Office crashing', 'Excel crashes every time I open large files.'),
        ('Application installation failed', 'Unable to install the new software, getting error.'),
        ('Software update problems', 'Windows Update keeps failing with error.'),
        ('VPN client not connecting', 'VPN client shows connection timeout error.'),
        ('Browser running slowly', 'Chrome is extremely slow and freezes frequently.'),
        ('Adobe Acrobat issues', 'PDF files won\'t open in Adobe Acrobat.'),
        ('Zoom crashes during calls', 'Zoom crashes when sharing screen.'),
        ('Teams not loading', 'Microsoft Teams stuck on loading screen.'),
        ('Outlook freezing', 'Outlook becomes unresponsive when switching folders.'),
        ('Software license expired', 'Getting message that license has expired.'),
    ],
    'Network': [
        ('Cannot connect to WiFi', 'Cannot connect to corporate WiFi network.'),
        ('Internet connection dropping', 'Internet drops every 10-15 minutes.'),
        ('VPN disconnects frequently', 'VPN connection drops randomly during work.'),
        ('Slow network performance', 'Network speeds are extremely slow.'),
        ('Cannot access internal sites', 'Cannot access internal websites.'),
        ('Remote desktop connection failed', 'Unable to connect via Remote Desktop.'),
        ('Network drive not accessible', 'Cannot access the shared network drive.'),
        ('Wireless keeps disconnecting', 'Wireless disconnects when walking around.'),
        ('Cannot ping servers', 'Cannot ping any internal servers.'),
        ('Ethernet not detected', 'Laptop doesn\'t detect ethernet cable.'),
    ],
    'Security': [
        ('Suspicious email received', 'Received phishing email asking for credentials.'),
        ('Potential malware detected', 'Antivirus detected threat but couldn\'t remove it.'),
        ('Unauthorized access attempt', 'Alert about login from unknown location.'),
        ('Security certificate error', 'Getting certificate warnings on internal sites.'),
        ('Ransomware attack', 'Files encrypted and there\'s a ransom note.'),
        ('Lost company laptop', 'Lost laptop on commute, contains sensitive data.'),
        ('Password compromised', 'Password may be compromised, suspicious activity.'),
        ('USB device with malware', 'Plugged in USB that may contain malware.'),
        ('Badge stolen', 'Employee badge stolen, needs deactivation.'),
        ('Data breach notification', 'Vendor experienced a data breach.'),
    ],
    'Access Management': [
        ('Need SharePoint access', 'Need access to Marketing SharePoint site.'),
        ('Account locked out', 'AD account locked, cannot log in.'),
        ('MFA not working', 'Multi-factor authentication not sending codes.'),
        ('Password reset needed', 'Forgot password, self-service not working.'),
        ('New employee access setup', 'Set up access for new team member.'),
        ('VPN access request', 'Need VPN access for working from home.'),
        ('Admin rights request', 'Need temporary admin rights for software.'),
        ('GitHub access needed', 'Need access to company GitHub.'),
        ('SSO not working', 'Single sign-on not working for Salesforce.'),
        ('Role change access update', 'Moved departments, need access updated.'),
    ],
    'Email': [
        ('Outlook not receiving emails', 'Haven\'t received emails in 3 hours.'),
        ('Cannot send large attachments', 'Error when sending attachments over 5MB.'),
        ('Calendar not syncing', 'Calendar not syncing with mobile phone.'),
        ('Shared mailbox issues', 'Cannot access team shared mailbox.'),
        ('Email signature not displaying', 'Signature shows incorrectly to recipients.'),
        ('Out of office not working', 'Auto-reply not sent to external contacts.'),
        ('Emails going to spam', 'Important emails going to spam folder.'),
        ('Distribution list not working', 'Emails to distribution list not delivered.'),
        ('Cannot recall email', 'Need to recall email sent to wrong person.'),
        ('Mailbox full', 'Getting mailbox full warnings.'),
    ],
    'Database': [
        ('SQL query performance', 'Query that took seconds now takes 10 minutes.'),
        ('Database connection timeout', 'Applications timing out connecting to DB.'),
        ('Need data restoration', 'Accidentally deleted records, need restored.'),
        ('Database not starting', 'Oracle instance won\'t start after reboot.'),
        ('Storage space running low', 'Database tablespace at 95% capacity.'),
        ('Query returning wrong results', 'Stored procedure returning incorrect data.'),
        ('Replication lag issues', 'Significant lag between primary and replica.'),
        ('Need database user account', 'Create read-only account for reporting.'),
        ('Backup job failing', 'Nightly backup failing for 3 nights.'),
        ('Table lock issues', 'Table locks preventing record updates.'),
    ],
    'Storage': [
        ('OneDrive not syncing', 'OneDrive stopped syncing, shows pending.'),
        ('Network drive full', 'Network drive full, cannot save files.'),
        ('Files disappeared', 'Important files disappeared from folder.'),
        ('Need storage quota increase', 'Exceeded OneDrive quota, need increase.'),
        ('SharePoint storage limit', 'SharePoint site reached storage limit.'),
        ('File version history missing', 'Previous versions not available.'),
        ('Cannot download large files', 'Timeout downloading large SharePoint files.'),
        ('Storage performance slow', 'Accessing network drive extremely slow.'),
        ('File permissions changed', 'Lost access to files I was using.'),
        ('Deleted files recovery', 'Need to recover deleted files from last week.'),
    ],
    'Printing': [
        ('Printer not appearing', 'Floor printer not in available printers.'),
        ('Print jobs stuck in queue', 'Print jobs stuck and won\'t print.'),
        ('Poor print quality', 'Prints coming out with streaks.'),
        ('Cannot print in color', 'Color printing option not available.'),
        ('Double-sided printing not working', 'Cannot print on both sides.'),
        ('Printer offline status', 'Printer shows offline even though on.'),
        ('Scanning not working', 'Cannot scan documents to email.'),
        ('Secure print not releasing', 'Secure print won\'t release at printer.'),
        ('New printer installation', 'Please install new department printer.'),
        ('Printer driver issues', 'Print errors after Windows update.'),
    ],
    'Backup': [
        ('Backup job failed', 'Weekly backup failed with error.'),
        ('Need file restoration', 'Deleted folder, need restored from backup.'),
        ('Backup storage full', 'Backup storage at capacity.'),
        ('Restore test needed', 'Need disaster recovery test before audit.'),
        ('Backup schedule change', 'Change backup window timing.'),
        ('Incremental backup issues', 'Incremental not capturing changes.'),
        ('Backup verification failed', 'Verification reports corrupted data.'),
        ('Need backup excluded folder', 'Exclude temp directory from backups.'),
        ('Backup agent not running', 'Backup agent on server has stopped.'),
        ('Recovery time too long', 'Last restore took 12 hours, too slow.'),
    ],
    'General Inquiry': [
        ('How to use VPN', 'How do I set up VPN from home?'),
        ('Password policy question', 'What are password requirements?'),
        ('IT documentation location', 'Where is documentation for new system?'),
        ('Software request process', 'What is process to request software?'),
        ('Help desk hours', 'What are IT help desk hours?'),
        ('Equipment return process', 'How do I return old laptop?'),
        ('Conference room booking', 'How to book room with video equipment?'),
        ('Training resources', 'Are there Microsoft 365 training resources?'),
        ('New hire checklist', 'What IT equipment does new hire get?'),
        ('Remote work guidelines', 'What are IT requirements for remote work?'),
    ],
    'Other': [
        ('General IT feedback', 'Feedback about recent IT improvements.'),
        ('IT project consultation', 'Need IT consultation for initiative.'),
        ('IT asset question', 'How to find asset tag on laptop?'),
        ('IT policy clarification', 'Clarification on acceptable use policy.'),
        ('Vendor software inquiry', 'Which vendors approved for cloud storage?'),
        ('IT budget question', 'Process for IT budget requests?'),
        ('Sustainability initiative', 'Any IT e-waste reduction initiatives?'),
        ('IT event support', 'Need IT support for company meeting.'),
        ('Compliance question', 'Data retention requirements for documents?'),
        ('Technology roadmap', 'When upgrading to Windows 12?'),
    ],
}

PREFIXES = ['', 'Urgent: ', 'Please help: ', 'Issue: ', 'Request: ', 'Need assistance: ']
SUFFIXES = ['', ' This is affecting my work.', ' Please help ASAP.', ' Been having this issue.', ' Blocking critical project.', ' Thank you.']

def generate_data(samples_per_cat=600):
    data = []
    for cat, templates in TEMPLATES.items():
        for _ in range(samples_per_cat):
            subj, desc = templates[np.random.randint(len(templates))]
            prefix, suffix = np.random.choice(PREFIXES), np.random.choice(SUFFIXES)
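            # Note: TicketPreprocessor.clean() lowercases all text later, so this casing change does not alter model inputs
            if np.random.random() < 0.3: subj = subj.lower()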
            data.append({'subject': f'{prefix}{subj}', 'description': f'{desc}{suffix}', 'category': cat})
    return pd.DataFrame(data)

df = generate_data(600).sample(frac=1, random_state=42).reset_index(drop=True)
train_df, temp = train_test_split(df, test_size=0.2, random_state=42, stratify=df['category'])
val_df, test_df = train_test_split(temp, test_size=0.5, random_state=42, stratify=temp['category'])

print(f'Dataset: Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}')
print(f'Categories: {df["category"].nunique()}')
print(df['category'].value_counts())
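
Each category draws 600 rows from only 10 templates × 6 prefixes × 6 suffixes, which is 360 distinct texts once `TicketPreprocessor.clean` lowercases the input. Identical tickets can therefore land in both the training and test splits, which would inflate the test scores. Below is a minimal sketch for measuring that overlap, assuming rows are passed as `(subject, description)` tuples; `normalize` and `split_overlap` are hypothetical helpers, not part of the notebook, and the rows are toy values.

```python
def normalize(text):
    """Lowercase and collapse whitespace, mirroring what the model ultimately sees."""
    return " ".join(str(text).lower().split())

def split_overlap(train_rows, test_rows):
    """Return the fraction of test rows whose normalized text also occurs in train."""
    seen = {(normalize(s), normalize(d)) for s, d in train_rows}
    if not test_rows:
        return 0.0
    hits = sum((normalize(s), normalize(d)) in seen for s, d in test_rows)
    return hits / len(test_rows)

# Toy rows in the same shape as the generated data (illustrative values only).
train_rows = [
    ("Urgent: Mouse cursor jumping", "The mouse cursor jumps around randomly. Thank you."),
    ("Mailbox full", "Getting mailbox full warnings."),
]
test_rows = [
    ("urgent: mouse cursor jumping", "The mouse cursor jumps around randomly. Thank you."),
    ("Badge stolen", "Employee badge stolen, needs deactivation."),
]
overlap = split_overlap(train_rows, test_rows)
print(f"Test rows also present in train: {overlap:.0%}")  # prints 50% for these toy rows
```

With the notebook's DataFrames, the inputs could be built as `list(zip(train_df['subject'], train_df['description']))` and likewise for `test_df`. A high overlap would suggest deduplicating before splitting, or treating the reported test metrics as optimistic.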

## Model and Training Setup

In [None]:
import re
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
import torch.nn as nn
import torch.nn.functional as F
from transformers import DistilBertModel
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from sklearn.metrics import f1_score, accuracy_score
from tqdm.auto import tqdm

class TicketPreprocessor:
    def __init__(self): self._email = re.compile(r'\b[\w.-]+@[\w.-]+\.\w+\b')
    def clean(self, t): return ' '.join(self._email.sub('[EMAIL]', str(t or '')).lower().split())
    def combine(self, s, d): return f'[SUBJECT] {self.clean(s)} [SEP] [DESCRIPTION] {self.clean(d)}'

class TicketDataset(Dataset):
    def __init__(self, df, tok, lm, ml=256):
        self.df, self.tok, self.lm, self.ml, self.pp = df.reset_index(drop=True), tok, lm, ml, TicketPreprocessor()
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        r = self.df.iloc[i]
        e = self.tok(self.pp.combine(r['subject'], r['description']), truncation=True, max_length=self.ml, padding='max_length', return_tensors='pt')
        return {'input_ids': e['input_ids'].squeeze(), 'attention_mask': e['attention_mask'].squeeze(), 'labels': torch.tensor(self.lm[r['category']])}

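class FocalLoss(nn.Module):
    """Focal loss: scales cross-entropy by (1 - p_t)^gamma to down-weight easy examples."""
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma  # alpha: optional per-class weight tensor
    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, reduction='none')  # per-sample -log(p_t)
        pt = torch.exp(-ce)  # probability assigned to the true class
        loss = ((1 - pt) ** self.gamma) * ce
        if self.alpha is not None:
            loss = self.alpha.to(logits.device)[targets] * loss
        return loss.mean()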

class TicketClassifier(nn.Module):
    def __init__(self, n_classes, model_name='distilbert-base-uncased', dropout=0.3):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained(model_name)
        # Classification head on top of the 768-dim DistilBERT hidden state
        self.classifier = nn.Sequential(
            nn.Dropout(dropout), nn.Linear(768, 256), nn.GELU(),
            nn.Dropout(dropout), nn.Linear(256, n_classes),
        )
    def forward(self, input_ids, attention_mask):
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        return self.classifier(hidden[:, 0, :])  # first token ([CLS]) as the sequence representation
    def predict_proba(self, input_ids, attention_mask):
        return torch.softmax(self.forward(input_ids, attention_mask), dim=-1)

# Setup
class_names = sorted(train_df['category'].unique())
label_map = {n: i for i, n in enumerate(class_names)}
idx_to_label = {v: k for k, v in label_map.items()}
n_classes = len(class_names)

MODEL_NAME = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

train_ds, val_ds, test_ds = TicketDataset(train_df, tokenizer, label_map), TicketDataset(val_df, tokenizer, label_map), TicketDataset(test_df, tokenizer, label_map)
BATCH = 32
train_loader, val_loader, test_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True), DataLoader(val_ds, batch_size=BATCH), DataLoader(test_ds, batch_size=BATCH)

model = TicketClassifier(n_classes).to(device)
print(f'Model params: {sum(p.numel() for p in model.parameters()):,}')
print(f'Classes: {class_names}')
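
To see what `FocalLoss` does to individual examples, the per-sample term can be computed by hand: cross-entropy is `-log(p_t)` and the focal factor is `(1 - p_t) ** gamma`, where `p_t` is the probability the model assigns to the true class. The following is a small pure-Python sketch of that arithmetic; the helper name `focal_term` is illustrative and the probabilities are made up.

```python
import math

def focal_term(p_true, gamma=2.0, alpha=1.0):
    """Per-sample focal loss for a given true-class probability p_true."""
    ce = -math.log(p_true)  # ordinary cross-entropy for this sample
    return alpha * (1.0 - p_true) ** gamma * ce

# Confident, uncertain, and badly wrong predictions (illustrative values).
for p in (0.95, 0.6, 0.1):
    ce = -math.log(p)
    print(f"p_t={p:.2f}  CE={ce:.4f}  focal={focal_term(p):.4f}")
```

With `gamma=0` the focal factor is 1 and the term reduces to plain cross-entropy. With `gamma=2.0`, a sample predicted at `p_t=0.95` keeps only `0.05 ** 2 = 0.0025` of its cross-entropy, so most of the gradient comes from the harder examples.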

## Training Loop

In [None]:
EPOCHS, LR, PATIENCE = 10, 2e-5, 3

optimizer = AdamW(model.parameters(), lr=LR, weight_decay=0.01)
scheduler = OneCycleLR(optimizer, max_lr=LR, total_steps=len(train_loader)*EPOCHS)

weights = torch.tensor(1.0 / train_df['category'].value_counts().sort_index().values, dtype=torch.float32)
weights = weights / weights.sum() * n_classes
criterion = FocalLoss(alpha=weights, gamma=2.0)

def train_epoch(model, loader):
    model.train()
    loss_sum = 0.0
    for b in tqdm(loader, leave=False):
        optimizer.zero_grad()
        logits = model(b['input_ids'].to(device), b['attention_mask'].to(device))
        loss = criterion(logits, b['labels'].to(device))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip gradient norm at 1.0
        optimizer.step()
        scheduler.step()  # OneCycleLR advances once per batch
        loss_sum += loss.item()
    return loss_sum / len(loader)

def evaluate(model, loader):
    model.eval(); preds, labels = [], []
    with torch.no_grad():
        for b in loader:
            preds.extend(model(b['input_ids'].to(device), b['attention_mask'].to(device)).argmax(1).cpu().numpy())
            labels.extend(b['labels'].numpy())
    return {'acc': accuracy_score(labels, preds), 'f1': f1_score(labels, preds, average='macro')}

best_f1, patience_cnt, history = 0, 0, []
print('Training...\n')

for epoch in range(EPOCHS):
    loss = train_epoch(model, train_loader)
    metrics = evaluate(model, val_loader)
    print(f'Epoch {epoch+1}/{EPOCHS} | Loss: {loss:.4f} | Acc: {metrics["acc"]:.4f} | F1: {metrics["f1"]:.4f}')
    history.append({'epoch': epoch+1, 'loss': loss, **metrics})
    if metrics['f1'] > best_f1:
        best_f1, patience_cnt = metrics['f1'], 0
        torch.save(model.state_dict(), 'best_model.pt')
        print(f'  -> Best model saved! F1: {best_f1:.4f}')
    else:
        patience_cnt += 1
        if patience_cnt >= PATIENCE: print(f'\nEarly stopping at epoch {epoch+1}'); break

print(f'\nDone! Best F1: {best_f1:.4f}')
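
The `alpha` weights passed to `FocalLoss` are inverse class frequencies rescaled to sum to `n_classes`. Because the stratified split leaves every category with the same number of training rows (480 each), those weights all come out as 1.0 in this notebook, so they only start to matter if the data becomes imbalanced. Here is a pure-Python sketch of the same arithmetic, using a hypothetical helper `inverse_freq_weights` and made-up counts for the imbalanced case.

```python
def inverse_freq_weights(counts):
    """Inverse-frequency weights rescaled so they sum to the number of classes."""
    inv = [1.0 / c for c in counts]
    total = sum(inv)
    return [w / total * len(counts) for w in inv]

balanced = inverse_freq_weights([480] * 12)         # per-class counts in the training split
imbalanced = inverse_freq_weights([900, 300, 100])  # illustrative counts only

print([round(w, 3) for w in balanced])
print([round(w, 3) for w in imbalanced])
```

In the imbalanced case the rarest class receives the largest weight, while the weights still sum to the number of classes, which keeps the overall loss scale comparable.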

## Evaluation

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

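# map_location lets the checkpoint load even if the runtime's device has changed
model.load_state_dict(torch.load('best_model.pt', map_location=device))
model.eval()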

preds, labels = [], []
with torch.no_grad():
    for b in tqdm(test_loader):
        preds.extend(model.predict_proba(b['input_ids'].to(device), b['attention_mask'].to(device)).argmax(1).cpu().numpy())
        labels.extend(b['labels'].numpy())

print('\n' + '='*60 + '\nCLASSIFICATION REPORT\n' + '='*60)
print(classification_report(labels, preds, target_names=class_names, digits=4))

plt.figure(figsize=(12, 10))
sns.heatmap(confusion_matrix(labels, preds, normalize='true'), annot=True, fmt='.2f', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted'); plt.ylabel('True'); plt.title('Confusion Matrix'); plt.xticks(rotation=45, ha='right'); plt.tight_layout(); plt.show()

## Test Predictions

In [None]:
pp = TicketPreprocessor()

def predict(subj, desc):
    model.eval()
    enc = tokenizer(pp.combine(subj, desc), return_tensors='pt', truncation=True, max_length=256, padding='max_length').to(device)
    with torch.no_grad(): probs = model.predict_proba(enc['input_ids'], enc['attention_mask'])[0].cpu().numpy()
    top3 = probs.argsort()[-3:][::-1]
    print(f'\nSubject: {subj}\nDescription: {desc}\nPredictions:')
    for i, idx in enumerate(top3): print(f'  {i+1}. {idx_to_label[idx]}: {probs[idx]*100:.1f}%')

predict('VPN not connecting', 'Cannot connect to corporate VPN from home, getting timeout.')
predict('Suspicious email', 'Email asking for password, looks like phishing.')
predict('Need SharePoint access', 'Joined new project, need access to SharePoint.')
predict('Laptop screen flickering', 'Screen flickering after Windows update.')

## Save & Download Model

In [None]:
# Save full checkpoint
checkpoint = {
    'model_state_dict': model.state_dict(),
    'class_names': class_names,
    'label_mapping': label_map,
    'best_f1': float(best_f1),  # plain float so the checkpoint holds no NumPy objects
    'n_classes': n_classes
}
torch.save(checkpoint, 'ticket_classifier.pt')
print('Model saved to ticket_classifier.pt')

# Download
from google.colab import files
files.download('ticket_classifier.pt')
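
For deployment, the saved checkpoint has to be loaded back into a model with the same architecture. The sketch below shows the round trip with a small stand-in module so that it runs on its own; `TinyHead`, the temporary path, and the label list are placeholders. In the notebook you would instantiate `TicketClassifier(checkpoint['n_classes'])` instead.

```python
import os
import tempfile

import torch
import torch.nn as nn

class TinyHead(nn.Module):
    """Stand-in for TicketClassifier so the example runs without downloading DistilBERT."""
    def __init__(self, n_classes):
        super().__init__()
        self.fc = nn.Linear(8, n_classes)
    def forward(self, x):
        return self.fc(x)

class_names = ["Email", "Hardware", "Network"]  # placeholder labels
model = TinyHead(len(class_names))
checkpoint = {
    "model_state_dict": model.state_dict(),
    "class_names": class_names,
    "label_mapping": {n: i for i, n in enumerate(class_names)},
    "n_classes": len(class_names),
}
path = os.path.join(tempfile.mkdtemp(), "ticket_classifier.pt")
torch.save(checkpoint, path)

# Load on CPU so the file also opens on machines without a GPU.
loaded = torch.load(path, map_location="cpu")
restored = TinyHead(loaded["n_classes"])
restored.load_state_dict(loaded["model_state_dict"])
restored.eval()

# Rebuild the index-to-label lookup from the stored mapping.
idx_to_label = {i: n for n, i in loaded["label_mapping"].items()}
with torch.no_grad():
    pred = restored(torch.zeros(1, 8)).argmax(dim=1).item()
print(idx_to_label[pred])
```

Storing the label mapping alongside the weights means the serving code does not need to recompute `class_names` from the training data.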

## Upload to Hugging Face (Optional)

Run this cell after `ticket_classifier.pt` has been saved to upload it directly to the Hugging Face Hub.

In [None]:
# Uncomment and run once ticket_classifier.pt has been saved
# from huggingface_hub import create_repo, upload_file
# 
# # Login first: run `!huggingface-cli login` in a cell
# 
# REPO_ID = 'YOUR_USERNAME/ticket-classifier'  # Change this!
# 
# create_repo(repo_id=REPO_ID, repo_type='model', exist_ok=True)
# upload_file(path_or_fileobj='ticket_classifier.pt', path_in_repo='ticket_classifier.pt', repo_id=REPO_ID)
# print(f'Uploaded to https://huggingface.co/{REPO_ID}')

---
## Summary

| Item | Value |
|------|-------|
| Total Samples | 7,200 |
| Training Samples | 5,760 |
| Validation / Test Samples | 720 / 720 |
| Categories | 12 |
| Model | DistilBERT |
| Best Val F1 | See training output |

**Next Steps:**
1. Download `ticket_classifier.pt`
2. Upload to Hugging Face
3. Deploy to Render