Skip to content

Commit

Permalink
Merge branch 'master' into fix-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
khizirsiddiqui committed Jun 11, 2021
2 parents 5d4bd7d + 1a1b0ec commit b97abd1
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 34 deletions.
5 changes: 0 additions & 5 deletions .github/workflows/python-package-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@ jobs:
pip install build
- name: Build package
run: python -m build
- name: Black
run: |
# stop the build if there are Python syntax errors or undefined names
black --check KD_Lib
black --check tests
- name: Test with pytest
run: |
pytest
1 change: 0 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ install:
- pip install -U tox-travis codecov black
- python setup.py install


jobs:
include:
# Deploy Documentation
Expand Down
40 changes: 19 additions & 21 deletions KD_Lib/KD/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,27 +53,25 @@ def __init__(
if self.log:
self.writer = SummaryWriter(logdir)

try:
torch.Tensor(0).to(device)
self.device = device
except:
print(
"Either an invalid device or CUDA is not available. Defaulting to CPU."
)
if device == "cpu":
self.device = torch.device("cpu")
elif device == "cuda":
if torch.cuda.is_available():
self.device = torch.device("cuda")
else:
print(
"Either an invalid device or CUDA is not available. Defaulting to CPU."
)
self.device = torch.device("cpu")

try:
if teacher_model:
self.teacher_model = teacher_model.to(self.device)
except:
else:
print("Warning!!! Teacher is NONE.")

self.student_model = student_model.to(self.device)
try:
self.loss_fn = loss_fn.to(self.device)
self.ce_fn = nn.CrossEntropyLoss().to(self.device)
except:
self.loss_fn = loss_fn
self.ce_fn = nn.CrossEntropyLoss()
print("Warning: Loss Function can't be moved to device.")
self.loss_fn = loss_fn.to(self.device)
self.ce_fn = nn.CrossEntropyLoss().to(self.device)

def train_teacher(
self,
Expand Down Expand Up @@ -142,7 +140,7 @@ def train_teacher(
)

loss_arr.append(epoch_loss)
print(f"Epoch: {ep+1}, Loss: {epoch_loss}, Accuracy: {epoch_acc}")
print("Epoch: {}, Loss: {}, Accuracy: {}".format(ep+1, epoch_loss, epoch_acc))

self.post_epoch_call(ep)

Expand Down Expand Up @@ -224,7 +222,7 @@ def _train_student(
)

loss_arr.append(epoch_loss)
print(f"Epoch: {ep+1}, Loss: {epoch_loss}, Accuracy: {epoch_acc}")
print("Epoch: {}, Loss: {}, Accuracy: {}".format(ep+1, epoch_loss, epoch_acc))

self.student_model.load_state_dict(self.best_student_model_weights)
if save_model:
Expand Down Expand Up @@ -290,7 +288,7 @@ def _evaluate_model(self, model, verbose=True):

if verbose:
print("-" * 80)
print(f"Validation Accuracy: {accuracy}")
print("Validation Accuracy: {}".format(accuracy))
return outputs, accuracy

def evaluate(self, teacher=False):
Expand All @@ -315,8 +313,8 @@ def get_parameters(self):
student_params = sum(p.numel() for p in self.student_model.parameters())

print("-" * 80)
print(f"Total parameters for the teacher network are: {teacher_params}")
print(f"Total parameters for the student network are: {student_params}")
print("Total parameters for the teacher network are: {}".format(teacher_params))
print("Total parameters for the student network are: {}".format(student_params))

def post_epoch_call(self, epoch):
"""
Expand Down
35 changes: 28 additions & 7 deletions tests/test_KD_Lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ProbShift,
LabelSmoothReg,
DML,
BaseClass
)

from KD_Lib.models import (
Expand Down Expand Up @@ -94,6 +95,10 @@ def test_resnet():
ResNet50(params)
ResNet101(params)
ResNet152(params)
ResNet34(params, att=True)
ResNet34(params, mean=True)
ResNet101(params, att=True)
ResNet101(params, mean=True)


def test_attention_model():
Expand Down Expand Up @@ -159,6 +164,22 @@ def test_LSTMNet():
# Strategy TESTS
#

def test_BaseClass()
teac = Shallow(hidden_size=400)
stud = Shallow(hidden_size=100)

t_optimizer = optim.SGD(teac.parameters(), 0.01)
s_optimizer = optim.SGD(stud.parameters(), 0.01)

distiller = BaseClass(
teac, stud, train_loader, test_loader, t_optimizer, s_optimizer, log=True
)

distiller.train_teacher(epochs=1, plot_losses=True, save_model=True)
distiller.train_student(epochs=1, plot_losses=True, save_model=True)
distiller.evaluate(teacher=False)
distiller.get_parameters()


def test_VanillaKD():
teac = Shallow(hidden_size=400)
Expand All @@ -168,11 +189,11 @@ def test_VanillaKD():
s_optimizer = optim.SGD(stud.parameters(), 0.01)

distiller = VanillaKD(
teac, stud, train_loader, test_loader, t_optimizer, s_optimizer
teac, stud, train_loader, test_loader, t_optimizer, s_optimizer, log=True
)

distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
distiller.train_student(epochs=1, plot_losses=False, save_model=False)
distiller.train_teacher(epochs=1, plot_losses=True, save_model=True)
distiller.train_student(epochs=1, plot_losses=True, save_model=True)
distiller.evaluate(teacher=False)
distiller.get_parameters()

Expand Down Expand Up @@ -289,8 +310,8 @@ def test_SelfTraining():


def test_mean_teacher():
teacher_params = [4, 4, 8, 4, 4]
student_params = [4, 4, 4, 4, 4]
teacher_params = [16, 16, 32, 16, 16]
student_params = [16, 16, 16, 16, 16]
teacher_model = ResNet50(teacher_params, 1, 10, mean=True)
student_model = ResNet18(student_params, 1, 10, mean=True)

Expand Down Expand Up @@ -488,7 +509,7 @@ def test_lottery_tickets():
teacher_params = [4, 4, 8, 4, 4]
teacher_model = ResNet50(teacher_params, 1, 10, True)
pruner = Lottery_Tickets_Pruner(teacher_model, train_loader, test_loader)
pruner.prune(num_iterations=1, train_iterations=1, valid_freq=1, print_freq=1)
pruner.prune(num_iterations=2, train_iterations=2, valid_freq=1, print_freq=1)


#
Expand Down Expand Up @@ -539,6 +560,6 @@ def test_qat_quantization():
model.fc.out_features = 10
optimizer = torch.optim.Adam(model.parameters())
quantizer = QAT_Quantizer(model, cifar_trainloader, cifar_testloader, optimizer)
quantized_model = quantizer.quantize(1, 1, -1, -1)
quantized_model = quantizer.quantize(1, 1, 1, 1)
quantizer.get_model_sizes()
quantizer.get_performance_statistics()

0 comments on commit b97abd1

Please sign in to comment.