In [6]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

# Hub repository that holds both the fine-tuned BART weights and its tokenizer.
MODEL_ID = "Shubhraweb89/bart_samsum_model"
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)




In [7]:
def generate_summary(article_text, max_input_length=1024, max_summary_length=128):
    """Sample one summary of ``article_text`` from the current model.

    Args:
        article_text: Raw text to summarise.
        max_input_length: Token budget for the encoder input; longer text is truncated.
        max_summary_length: ``max_length`` forwarded to ``model.generate``.

    Returns:
        (summary, output_ids): the decoded summary string (special tokens removed)
        and the raw token-id tensor returned by ``model.generate``. The ids are
        returned so the caller can score exactly the sequence that was sampled.
    """
    encoded = tokenizer(
        article_text,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    ).to(model.device)  # inputs must live on the same device as the model
    output_ids = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,  # explicit mask instead of letting generate infer one
        do_sample=True,  # stochastic decoding, so repeated calls can yield different summaries
        top_k=50,
        top_p=0.95,
        max_length=max_summary_length,
        num_return_sequences=1,
    )
    summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return summary, output_ids


In [8]:
def get_user_feedback(summary, prompt_fn=None):
    """Show a generated summary and ask the user to rate it.

    Keeps asking until the answer is exactly ``1`` (like) or ``0`` (dislike),
    so a typo neither raises ``ValueError`` nor produces an out-of-range rating.

    Args:
        summary: Text shown to the user.
        prompt_fn: Callable that takes a prompt string and returns the user's
            answer; defaults to the built-in ``input``. Injectable so the
            function can be driven non-interactively.

    Returns:
        int: 1 for like, 0 for dislike.
    """
    ask = input if prompt_fn is None else prompt_fn
    print("Generated Summary:\n", summary)
    while True:
        answer = ask("Like = 1, Dislike = 0: ").strip()
        if answer in ("0", "1"):
            return int(answer)
        print("Please enter 1 (like) or 0 (dislike).")


In [9]:
import torch.nn.functional as F

def compute_log_probs(input_ids, output_ids, policy=None, pad_token_id=None):
    """Mean per-token log-probability of ``output_ids`` given ``input_ids``.

    Uses teacher forcing: the decoder is fed ``output_ids[:, :-1]`` and its
    logits are scored against ``output_ids[:, 1:]``. Logit row ``t`` therefore
    already predicts token ``t + 1``, so the logits must NOT be sliced again;
    doing so leaves one fewer row than targets and ``cross_entropy`` fails.

    Args:
        input_ids: Encoder token ids, batch-first.
        output_ids: Generated decoder token ids (including the decoder start
            token), batch-first; needs at least two positions.
        policy: Seq2seq model to score with; defaults to the global ``model``.
        pad_token_id: Target id to ignore; defaults to ``tokenizer.pad_token_id``.

    Returns:
        0-dim tensor: the negated mean cross-entropy, i.e. the average log-prob
        per non-pad target token. It keeps the autograd graph of ``policy``.

    Raises:
        ValueError: if ``output_ids`` has fewer than two positions.
    """
    if output_ids.size(1) < 2:
        raise ValueError("output_ids must contain at least two tokens to score a continuation")
    if policy is None:
        policy = model
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    # cross_entropy needs an int ignore_index; fall back to its default when no pad id exists.
    ignore_index = -100 if pad_token_id is None else pad_token_id

    decoder_input_ids = output_ids[:, :-1]
    target_ids = output_ids[:, 1:]  # shifted by one for teacher forcing
    logits = policy(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

    nll = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        target_ids.reshape(-1),
        ignore_index=ignore_index,
        reduction="mean",
    )
    return -nll


In [10]:
from torch.optim import Adam
import torch.nn.functional as F

# Small step size: every update in the loop below is based on one sampled summary and a single rating.
LEARNING_RATE = 5e-6
optimizer = Adam(params=model.parameters(), lr=LEARNING_RATE)

def compute_log_probs(input_ids, output_ids, policy=None, pad_token_id=None):
    """Mean per-token log-probability of ``output_ids`` given ``input_ids``.

    Teacher forcing: the decoder consumes ``output_ids[:, :-1]`` and each of its
    logit rows is scored against the next token in ``output_ids[:, 1:]``. A
    seq2seq decoder emits one logit row per decoder input position, so logits
    and labels already have the same length and need no trimming.

    Args:
        input_ids: Encoder token ids, batch-first.
        output_ids: Generated decoder token ids (including the decoder start
            token), batch-first; needs at least two positions.
        policy: Seq2seq model to score with; defaults to the global ``model``.
        pad_token_id: Label id to ignore; defaults to ``tokenizer.pad_token_id``.

    Returns:
        0-dim tensor: negated mean cross-entropy (average log-prob per non-pad
        label), still attached to the autograd graph of ``policy``.

    Raises:
        ValueError: if ``output_ids`` has fewer than two positions.
    """
    if output_ids.size(1) < 2:
        raise ValueError("output_ids must contain at least two tokens to score a continuation")
    if policy is None:
        policy = model
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    # cross_entropy needs an int ignore_index; use its default when there is no pad id.
    ignore_index = -100 if pad_token_id is None else pad_token_id

    # Shift decoder input ids and labels
    decoder_input_ids = output_ids[:, :-1]
    labels = output_ids[:, 1:]

    logits = policy(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits

    loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=ignore_index,
        reduction="mean",
    )
    return -loss  # log-prob = -NLL

# Human-in-the-loop policy-gradient loop: sample a summary, ask for a like/dislike,
# then step the optimizer on loss = -reward * (mean token log-prob). A like
# pushes that log-prob up and a dislike pushes it down.
NUM_FEEDBACK_STEPS = 5

# Drop any gradients left over from an earlier, interrupted run of this cell.
optimizer.zero_grad()

for step in range(NUM_FEEDBACK_STEPS):
    article = input("Paste article text: ")
    if not article.strip():
        print(f"Step {step} → skipped (empty article)")
        continue

    summary, output_ids = generate_summary(article)
    # Delegate to get_user_feedback instead of duplicating its prompt inline.
    feedback = get_user_feedback(summary)

    # Re-encode the article for the scoring forward pass.
    input_ids = tokenizer(article, return_tensors="pt", truncation=True, max_length=1024).input_ids

    # Ensure everything is on the same device
    input_ids = input_ids.to(model.device)
    output_ids = output_ids.to(model.device)

    log_prob = compute_log_probs(input_ids, output_ids)
    reward = 1.0 if feedback == 1 else -1.0

    loss = -reward * log_prob
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(f"Step {step} → Reward: {reward} → Loss: {loss.item():.4f}")




Generated Summary:
  "Chinese blessing scams" occur worldwide. Chinese blessing scams have been reported worldwide for 25 years. They are targeted Asian women usually Asian women's wealth.


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step 0 → Reward: 1.0 → Loss: 1.1465
Generated Summary:
  scam where an elderly lady is convinced family member is cursed and should have her wealth blessed. Authorities are investigating it worldwide.
Step 1 → Reward: -1.0 → Loss: -1.4784
Generated Summary:
 Erin Patterson has been found guilty of three counts of murder and attempted murder. 12-member jury reached the verdict after around six days of deliberation following a 10-week trial in Morwell, an hour's drive from the dining room in Leongatha, Victoria, where the lethal lunch was served in July 2023.
Step 2 → Reward: -1.0 → Loss: -0.3888
Generated Summary:
  mushrooms baked in a Beef Wellington lunch were served to a group of three in Morwell, Victoria. Erin Patterson was convicted of three counts of murder and the attempted murder of the lone survivor. 
Step 3 → Reward: 1.0 → Loss: 0.7419
Generated Summary:
  poison and fabricated a cancer claim in order to get her lunch invitation.
Step 4 → Reward: -1.0 → Loss: -1.3624


In [11]:
# Write the fine-tuned weights and the tokenizer files into one local directory.
# The last expression's value (the list of tokenizer files written) is shown as the cell output.
save_path = "bart_summarizer_with_rl"
model.save_pretrained(save_directory=save_path)
tokenizer.save_pretrained(save_directory=save_path)




('bart_summarizer_with_rl\\tokenizer_config.json',
 'bart_summarizer_with_rl\\special_tokens_map.json',
 'bart_summarizer_with_rl\\vocab.json',
 'bart_summarizer_with_rl\\merges.txt',
 'bart_summarizer_with_rl\\added_tokens.json',
 'bart_summarizer_with_rl\\tokenizer.json')

In [9]:
!pip install huggingface_hub





[notice] A new release of pip available: 22.3.1 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
!pip install transformers





[notice] A new release of pip available: 22.3.1 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [13]:
!huggingface-cli login


^C


In [None]:
# Same target directory as the earlier save cell; redundant if that cell already ran.
target_dir = "bart_summarizer_with_rl"
model.save_pretrained(save_directory=target_dir)
tokenizer.save_pretrained(save_directory=target_dir)


In [12]:
!huggingface-cli repo create my-news-summarizer


Traceback (most recent call last):
  File "E:\CIS_Project\venv\Lib\site-packages\huggingface_hub\utils\_http.py", line 409, in hf_raise_for_status
    response.raise_for_status()
  File "E:\CIS_Project\venv\Lib\site-packages\requests\models.py", line 1026, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 401 Client Error: Unauthorized for url: https://huggingface.co/api/repos/create

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "e:\CIS_Project\venv\Scripts\huggingface-cli.exe\__main__.py", line 7, in <module>
  File "E:\CIS_Project\venv\Lib\site-packages\huggingface_hub\commands\huggingface_cli.py", line 59, in main
    service.run()
  File "E:\CIS_Project\venv\Lib\site-packages\huggingface_hub\commands\repo.py", line 137, in run
    repo_url = self._api.create_rep