In [2]:
from datasets import load_dataset
from langchain.schema import Document
import torch 

ds = load_dataset("darrow-ai/USClassActions")

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
dataset = ds["train"]

from langchain_chroma import Chroma
from langchain.embeddings import HuggingFaceEmbeddings

embeddings_model_name = "sentence-transformers/all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)

# Create the vector store
vectorstore = Chroma(embedding_function=embeddings)

data = []
for i in range(len(dataset)):
    if type(dataset[i]["target_text"] == str):
        filtered_ds = Document(page_content= str(dataset[i]["target_text"]), metadata={"verdict": dataset[i]["verdict"]})
        data.append(filtered_ds)

vectorstore.add_documents(data)

  embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name)


['cb373141-1bcc-4d1c-a379-6d95d69893b6',
 'b6613334-3b18-4faf-afe3-2e3e44e7fecf',
 '925b4d74-81bc-413e-8751-42b36655c28a',
 '6af5b928-bb5a-4086-839d-9c3680388da6',
 'a4d18a45-b327-4da2-8294-36164706ac48',
 'cbc31e96-5c2f-4e50-8fd7-681b970890f6',
 'b3acec79-a41c-4b24-a7fe-cd547943afe7',
 'ca4bf3a0-9a78-4655-b890-0661ec04f2fa',
 '612b8581-815c-44cd-ba46-b4d4f9db13b6',
 'ab9fb0c2-205c-47ce-8184-539fb2388298',
 '1827b281-648a-4706-96fb-edc25cb32e9a',
 '7ad359b7-5631-4308-9972-f68207ea0451',
 '5bc1d5d0-2ba1-4de9-a6e1-06f51cd3958f',
 'c82fa970-7f37-404f-a48d-f2c3126f5570',
 'dd89311b-3967-481a-9fb4-72eeb490abed',
 '37e47eec-895f-46f6-bb31-dacd8acf9695',
 '0a2d812c-9635-4095-a519-a96d223f03d0',
 '6cdf9810-5750-45b9-ae51-3dd78fcffb45',
 '95940e44-cae7-4736-a182-780c600d7144',
 '185d87ec-4e6d-44ba-83c0-88b17768ab27',
 'cf3c9592-d1b8-4594-b736-6a914baee3a4',
 '292888ac-5e51-4dc2-80d4-1b52c9151298',
 '5f559a61-e2c9-4483-9f6a-d95e3d449196',
 'e2c66461-03fb-4f79-a5c8-d0ad1a770c36',
 '1a2f7b2f-de61-

In [32]:
similarity_count = 3
def search_similar_results(query):

    results = vectorstore.similarity_search(query, k=similarity_count)
    total = 0
    for doc in results:
        print(f"Metadata: {doc.metadata}")
        
        if doc.metadata["verdict"] == "win":
            total +=1

    # Give the average verdict
    return total / similarity_count

In [41]:
query = "At no point did Plaintiff consent to receiving such messages from Defendant. Furthermore, Plaintiff had not engaged in business with Defendant in the 18 months prior, nor made any inquiries about their products in the preceding three months. Plaintiff, therefore, asserts that he and other members of the class have been subjected to unlawful advertising under 47 U.S.C. § 227. Plaintiff now seeks certification of the Class under Rule 23(b)(2) and 23(b)(3), defined as all persons who received one or more unsolicited SMS messages from the Defendant. The Defendant, along with its employees and those who consented to receive such messages or had an established business relationship with Defendant, are excluded from the class."

In [None]:
search_similar_results(query)

In [3]:
from transformers import BertForSequenceClassification
from transformers import BertTokenizer

BERT_MODEL = "textattack/bert-base-uncased-yelp-polarity"

tokenizer = BertTokenizer.from_pretrained(BERT_MODEL)


model = BertForSequenceClassification.from_pretrained(BERT_MODEL, num_labels = 2)

In [36]:
from openai import OpenAI
import requests
import json

YOUR_API_KEY = "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"


# Get the summary input from the lawyer
def summarise(summarising_input):

    # Construct the API payload
    payload = {
        "model": "llama-3.1-sonar-small-128k-online",
        "messages": [
            {"role": "system", "content": "You are an artificial intelligence assistant and you need to engage in a helpful, detailed, polite conversation with a lawyer."},
            {"role": "user", "content": f"Your task is to generate a short summary of a law case in at most 150 words taking an unbiased view taking account points in favour and against. Do not include any indicator tokens that indicate this is a summary. Put it as a single paragraph{summarising_input}"}
        ]
    }

    # Define the headers for the request
    headers = {
        "Authorization": f"Bearer {YOUR_API_KEY}",
        "accept": "application/json",
        "content-type": "application/json"
    }

    # Make the request to the API
    response = requests.post("https://api.perplexity.ai/chat/completions", json=payload, headers=headers)

    # Check if the request was successful
    if response.status_code == 200:
        data = response.json()

        # Extract the content directly if choices are present
        if 'choices' in data and data['choices']:
            content = data['choices'][0]['message']['content']  # Directly access the first choice
            
            # Print the formatted content
            return (content.strip())

    else:
        # Handle request errors
        print(f"Request failed with status code {response.status_code}: {response.text}")


In [37]:
def tokenize_text(text):
    create_summary = summarise(text)
    inputs = tokenizer(create_summary, return_tensors="pt")

    return inputs

In [42]:
tokenised = tokenize_text(query)

In [43]:
logits = model(**tokenised)

In [44]:
torch.softmax((logits.logits), -1)

tensor([[0.4910, 0.5090]], grad_fn=<SoftmaxBackward0>)

In [15]:
from torch.utils.data import Dataset, DataLoader

class Legal_Dataset(Dataset):
    def __init__(self):
        self.data = dataset

    def __len__(self):
        return 2700
    
    def __getitem__(self, idx):
        text = dataset[idx]["target_text"]
        verdict = dataset[idx]["verdict"]

        if verdict == "win":
            res = 1
        else:
            res = 0
        
        text = tokenize_text(text)
        return(text, res)

In [23]:
ds = Legal_Dataset()
train_dataloader = DataLoader(ds, batch_size=1, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)
loss_fn = torch.nn.CrossEntropyLoss()

In [25]:
losses = []
loss_accumilation = None
loss_count = 0

for i, data in enumerate(train_dataloader):
    loss_count += 1

    inputs, labels = data
    optimizer.zero_grad()
    
    outputs = model(inputs["input_ids"][0]).logits

    loss = loss_fn(outputs, labels)
    if loss_accumilation == None:
        loss_accumilation = loss
    else:
        loss_accumilation += loss

    if loss_count % 25 == 0:
        loss_accumilation.backward()
        optimizer.step()

        print(i)
        print(loss_accumilation)

        losses.append(loss_accumilation)
        loss_accumilation = 0

    print(i)

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
tensor(17.5812, grad_fn=<AddBackward0>)
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
tensor(17.5332, grad_fn=<AddBackward0>)
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
tensor(17.6077, grad_fn=<AddBackward0>)
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
tensor(17.6866, grad_fn=<AddBackward0>)
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
tensor(16.6527, grad_fn=<AddBackward0>)
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
tensor(17.8780, grad_fn=<AddBackward0>)
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
tensor(17.1653, grad_fn=<AddBackward0>)
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
tensor

In [None]:
torch.save(model.state_dict(), "model.pth")